modeling_vits.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408
  1. # coding=utf-8
  2. # Copyright 2023 The Kakao Enterprise Authors and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch VITS model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Any, Optional, Union
  19. import numpy as np
  20. import torch
  21. from torch import nn
  22. from ...activations import ACT2FN
  23. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  24. from ...integrations.fsdp import is_fsdp_managed_module
  25. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import BaseModelOutput, ModelOutput
  28. from ...modeling_utils import PreTrainedModel
  29. from ...utils import auto_docstring, logging
  30. from .configuration_vits import VitsConfig
  31. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Describes the outputs for the VITS model, with potential hidden states and attentions.
    """
)
class VitsModelOutput(ModelOutput):
    r"""
    waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
        The final audio waveform predicted by the model.
    sequence_lengths (`torch.FloatTensor` of shape `(batch_size,)`):
        The length in samples of each element in the `waveform` batch.
    spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
        The log-mel spectrogram predicted at the output of the flow model. This spectrogram is passed to the Hi-Fi
        GAN decoder model to obtain the final audio waveform.
    """

    # Every field defaults to `None`, so an output object can be built with any subset populated.
    waveform: Optional[torch.FloatTensor] = None
    sequence_lengths: Optional[torch.FloatTensor] = None
    # NOTE(review): the annotation says tuple, but the docstring above describes a single tensor of shape
    # `(batch_size, sequence_length, num_bins)` -- confirm which is intended against the producer.
    spectrogram: Optional[tuple[torch.FloatTensor]] = None
    # `hidden_states` / `attentions` are not described in the docstring above; presumably `auto_docstring`
    # supplies their documentation -- TODO confirm.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
    Describes the outputs for the VITS text encoder model, with potential hidden states and attentions.
    """
)
class VitsTextEncoderOutput(ModelOutput):
    r"""
    prior_means (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The predicted mean values of the prior distribution for the latent text variables.
    prior_log_variances (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        The predicted log-variance values of the prior distribution for the latent text variables.
    """

    # Every field defaults to `None`, so an output object can be built with any subset populated.
    # `last_hidden_state`, `hidden_states` and `attentions` are not described in the docstring above;
    # presumably `auto_docstring` supplies their documentation -- TODO confirm.
    last_hidden_state: Optional[torch.FloatTensor] = None
    prior_means: Optional[torch.FloatTensor] = None
    prior_log_variances: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
  71. @torch.jit.script
  72. def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
  73. in_act = input_a + input_b
  74. t_act = torch.tanh(in_act[:, :num_channels, :])
  75. s_act = torch.sigmoid(in_act[:, num_channels:, :])
  76. acts = t_act * s_act
  77. return acts
  78. def _unconstrained_rational_quadratic_spline(
  79. inputs,
  80. unnormalized_widths,
  81. unnormalized_heights,
  82. unnormalized_derivatives,
  83. reverse=False,
  84. tail_bound=5.0,
  85. min_bin_width=1e-3,
  86. min_bin_height=1e-3,
  87. min_derivative=1e-3,
  88. ):
  89. """
  90. This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the
  91. `tail_bound`, the transform behaves as an identity function.
  92. Args:
  93. inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  94. Second half of the hidden-states input to the Vits convolutional flow module.
  95. unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  96. First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  97. layer in the convolutional flow module
  98. unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  99. Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  100. layer in the convolutional flow module
  101. unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  102. Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  103. layer in the convolutional flow module
  104. reverse (`bool`, *optional*, defaults to `False`):
  105. Whether the model is being run in reverse mode.
  106. tail_bound (`float`, *optional* defaults to 5):
  107. Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
  108. transform behaves as an identity function.
  109. min_bin_width (`float`, *optional*, defaults to 1e-3):
  110. Minimum bin value across the width dimension for the piecewise rational quadratic function.
  111. min_bin_height (`float`, *optional*, defaults to 1e-3):
  112. Minimum bin value across the height dimension for the piecewise rational quadratic function.
  113. min_derivative (`float`, *optional*, defaults to 1e-3):
  114. Minimum bin value across the derivatives for the piecewise rational quadratic function.
  115. Returns:
  116. outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  117. Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits
  118. applied.
  119. log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  120. Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound`
  121. limits applied.
  122. """
  123. inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
  124. outside_interval_mask = ~inside_interval_mask
  125. outputs = torch.zeros_like(inputs)
  126. log_abs_det = torch.zeros_like(inputs)
  127. constant = np.log(np.exp(1 - min_derivative) - 1)
  128. unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
  129. unnormalized_derivatives[..., 0] = constant
  130. unnormalized_derivatives[..., -1] = constant
  131. outputs[outside_interval_mask] = inputs[outside_interval_mask]
  132. log_abs_det[outside_interval_mask] = 0.0
  133. outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline(
  134. inputs=inputs[inside_interval_mask],
  135. unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
  136. unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
  137. unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
  138. reverse=reverse,
  139. tail_bound=tail_bound,
  140. min_bin_width=min_bin_width,
  141. min_bin_height=min_bin_height,
  142. min_derivative=min_derivative,
  143. )
  144. return outputs, log_abs_det
  145. def _rational_quadratic_spline(
  146. inputs,
  147. unnormalized_widths,
  148. unnormalized_heights,
  149. unnormalized_derivatives,
  150. reverse,
  151. tail_bound,
  152. min_bin_width,
  153. min_bin_height,
  154. min_derivative,
  155. ):
  156. """
  157. This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
  158. function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`.
  159. Args:
  160. inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  161. Second half of the hidden-states input to the Vits convolutional flow module.
  162. unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  163. First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  164. layer in the convolutional flow module
  165. unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  166. Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  167. layer in the convolutional flow module
  168. unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
  169. Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
  170. layer in the convolutional flow module
  171. reverse (`bool`):
  172. Whether the model is being run in reverse mode.
  173. tail_bound (`float`):
  174. Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
  175. transform behaves as an identity function.
  176. min_bin_width (`float`):
  177. Minimum bin value across the width dimension for the piecewise rational quadratic function.
  178. min_bin_height (`float`):
  179. Minimum bin value across the height dimension for the piecewise rational quadratic function.
  180. min_derivative (`float`):
  181. Minimum bin value across the derivatives for the piecewise rational quadratic function.
  182. Returns:
  183. outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  184. Hidden-states as transformed by the piecewise rational quadratic function.
  185. log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`:
  186. Logarithm of the absolute value of the determinants corresponding to the `outputs`.
  187. """
  188. upper_bound = tail_bound
  189. lower_bound = -tail_bound
  190. if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
  191. raise ValueError("Input to a transform is not within its domain")
  192. num_bins = unnormalized_widths.shape[-1]
  193. if min_bin_width * num_bins > 1.0:
  194. raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
  195. if min_bin_height * num_bins > 1.0:
  196. raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")
  197. widths = nn.functional.softmax(unnormalized_widths, dim=-1)
  198. widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
  199. cumwidths = torch.cumsum(widths, dim=-1)
  200. cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
  201. cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
  202. cumwidths[..., 0] = lower_bound
  203. cumwidths[..., -1] = upper_bound
  204. widths = cumwidths[..., 1:] - cumwidths[..., :-1]
  205. derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)
  206. heights = nn.functional.softmax(unnormalized_heights, dim=-1)
  207. heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
  208. cumheights = torch.cumsum(heights, dim=-1)
  209. cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
  210. cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
  211. cumheights[..., 0] = lower_bound
  212. cumheights[..., -1] = upper_bound
  213. heights = cumheights[..., 1:] - cumheights[..., :-1]
  214. bin_locations = cumheights if reverse else cumwidths
  215. bin_locations[..., -1] += 1e-6
  216. bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
  217. bin_idx = bin_idx[..., None]
  218. input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
  219. input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
  220. input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
  221. delta = heights / widths
  222. input_delta = delta.gather(-1, bin_idx)[..., 0]
  223. input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
  224. input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
  225. input_heights = heights.gather(-1, bin_idx)[..., 0]
  226. intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
  227. if not reverse:
  228. theta = (inputs - input_cumwidths) / input_bin_widths
  229. theta_one_minus_theta = theta * (1 - theta)
  230. numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
  231. denominator = input_delta + intermediate1 * theta_one_minus_theta
  232. outputs = input_cumheights + numerator / denominator
  233. derivative_numerator = input_delta.pow(2) * (
  234. input_derivatives_plus_one * theta.pow(2)
  235. + 2 * input_delta * theta_one_minus_theta
  236. + input_derivatives * (1 - theta).pow(2)
  237. )
  238. log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
  239. return outputs, log_abs_det
  240. else:
  241. # find the roots of a quadratic equation
  242. intermediate2 = inputs - input_cumheights
  243. intermediate3 = intermediate2 * intermediate1
  244. a = input_heights * (input_delta - input_derivatives) + intermediate3
  245. b = input_heights * input_derivatives - intermediate3
  246. c = -input_delta * intermediate2
  247. discriminant = b.pow(2) - 4 * a * c
  248. if not (discriminant >= 0).all():
  249. raise RuntimeError(f"invalid discriminant {discriminant}")
  250. root = (2 * c) / (-b - torch.sqrt(discriminant))
  251. outputs = root * input_bin_widths + input_cumwidths
  252. theta_one_minus_theta = root * (1 - root)
  253. denominator = input_delta + intermediate1 * theta_one_minus_theta
  254. derivative_numerator = input_delta.pow(2) * (
  255. input_derivatives_plus_one * root.pow(2)
  256. + 2 * input_delta * theta_one_minus_theta
  257. + input_derivatives * (1 - root).pow(2)
  258. )
  259. log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
  260. return outputs, -log_abs_det
  261. class VitsWaveNet(torch.nn.Module):
  262. def __init__(self, config: VitsConfig, num_layers: int):
  263. super().__init__()
  264. self.hidden_size = config.hidden_size
  265. self.num_layers = num_layers
  266. self.in_layers = torch.nn.ModuleList()
  267. self.res_skip_layers = torch.nn.ModuleList()
  268. self.dropout = nn.Dropout(config.wavenet_dropout)
  269. if hasattr(nn.utils.parametrizations, "weight_norm"):
  270. weight_norm = nn.utils.parametrizations.weight_norm
  271. else:
  272. weight_norm = nn.utils.weight_norm
  273. if config.speaker_embedding_size != 0:
  274. cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
  275. self.cond_layer = weight_norm(cond_layer, name="weight")
  276. for i in range(num_layers):
  277. dilation = config.wavenet_dilation_rate**i
  278. padding = (config.wavenet_kernel_size * dilation - dilation) // 2
  279. in_layer = torch.nn.Conv1d(
  280. in_channels=config.hidden_size,
  281. out_channels=2 * config.hidden_size,
  282. kernel_size=config.wavenet_kernel_size,
  283. dilation=dilation,
  284. padding=padding,
  285. )
  286. in_layer = weight_norm(in_layer, name="weight")
  287. self.in_layers.append(in_layer)
  288. # last one is not necessary
  289. if i < num_layers - 1:
  290. res_skip_channels = 2 * config.hidden_size
  291. else:
  292. res_skip_channels = config.hidden_size
  293. res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
  294. res_skip_layer = weight_norm(res_skip_layer, name="weight")
  295. self.res_skip_layers.append(res_skip_layer)
  296. def forward(self, inputs, padding_mask, global_conditioning=None):
  297. outputs = torch.zeros_like(inputs)
  298. num_channels_tensor = torch.IntTensor([self.hidden_size])
  299. if global_conditioning is not None:
  300. global_conditioning = self.cond_layer(global_conditioning)
  301. for i in range(self.num_layers):
  302. hidden_states = self.in_layers[i](inputs)
  303. if global_conditioning is not None:
  304. cond_offset = i * 2 * self.hidden_size
  305. global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :]
  306. else:
  307. global_states = torch.zeros_like(hidden_states)
  308. acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
  309. acts = self.dropout(acts)
  310. res_skip_acts = self.res_skip_layers[i](acts)
  311. if i < self.num_layers - 1:
  312. res_acts = res_skip_acts[:, : self.hidden_size, :]
  313. inputs = (inputs + res_acts) * padding_mask
  314. outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
  315. else:
  316. outputs = outputs + res_skip_acts
  317. return outputs * padding_mask
  318. def remove_weight_norm(self):
  319. if self.speaker_embedding_size != 0:
  320. torch.nn.utils.remove_weight_norm(self.cond_layer)
  321. for layer in self.in_layers:
  322. torch.nn.utils.remove_weight_norm(layer)
  323. for layer in self.res_skip_layers:
  324. torch.nn.utils.remove_weight_norm(layer)
  325. class VitsPosteriorEncoder(nn.Module):
  326. def __init__(self, config: VitsConfig):
  327. super().__init__()
  328. self.out_channels = config.flow_size
  329. self.conv_pre = nn.Conv1d(config.spectrogram_bins, config.hidden_size, 1)
  330. self.wavenet = VitsWaveNet(config, num_layers=config.posterior_encoder_num_wavenet_layers)
  331. self.conv_proj = nn.Conv1d(config.hidden_size, self.out_channels * 2, 1)
  332. def forward(self, inputs, padding_mask, global_conditioning=None):
  333. inputs = self.conv_pre(inputs) * padding_mask
  334. inputs = self.wavenet(inputs, padding_mask, global_conditioning)
  335. stats = self.conv_proj(inputs) * padding_mask
  336. mean, log_stddev = torch.split(stats, self.out_channels, dim=1)
  337. sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask
  338. return sampled, mean, log_stddev
  339. # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
  340. class HifiGanResidualBlock(nn.Module):
  341. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
  342. super().__init__()
  343. self.leaky_relu_slope = leaky_relu_slope
  344. self.convs1 = nn.ModuleList(
  345. [
  346. nn.Conv1d(
  347. channels,
  348. channels,
  349. kernel_size,
  350. stride=1,
  351. dilation=dilation[i],
  352. padding=self.get_padding(kernel_size, dilation[i]),
  353. )
  354. for i in range(len(dilation))
  355. ]
  356. )
  357. self.convs2 = nn.ModuleList(
  358. [
  359. nn.Conv1d(
  360. channels,
  361. channels,
  362. kernel_size,
  363. stride=1,
  364. dilation=1,
  365. padding=self.get_padding(kernel_size, 1),
  366. )
  367. for _ in range(len(dilation))
  368. ]
  369. )
  370. def get_padding(self, kernel_size, dilation=1):
  371. return (kernel_size * dilation - dilation) // 2
  372. def apply_weight_norm(self):
  373. weight_norm = nn.utils.weight_norm
  374. if hasattr(nn.utils.parametrizations, "weight_norm"):
  375. weight_norm = nn.utils.parametrizations.weight_norm
  376. for layer in self.convs1:
  377. weight_norm(layer)
  378. for layer in self.convs2:
  379. weight_norm(layer)
  380. def remove_weight_norm(self):
  381. for layer in self.convs1:
  382. nn.utils.remove_weight_norm(layer)
  383. for layer in self.convs2:
  384. nn.utils.remove_weight_norm(layer)
  385. def forward(self, hidden_states):
  386. for conv1, conv2 in zip(self.convs1, self.convs2):
  387. residual = hidden_states
  388. hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
  389. hidden_states = conv1(hidden_states)
  390. hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
  391. hidden_states = conv2(hidden_states)
  392. hidden_states = hidden_states + residual
  393. return hidden_states
  394. class VitsHifiGan(nn.Module):
  395. def __init__(self, config: VitsConfig):
  396. super().__init__()
  397. self.config = config
  398. self.num_kernels = len(config.resblock_kernel_sizes)
  399. self.num_upsamples = len(config.upsample_rates)
  400. self.conv_pre = nn.Conv1d(
  401. config.flow_size,
  402. config.upsample_initial_channel,
  403. kernel_size=7,
  404. stride=1,
  405. padding=3,
  406. )
  407. self.upsampler = nn.ModuleList()
  408. for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
  409. self.upsampler.append(
  410. nn.ConvTranspose1d(
  411. config.upsample_initial_channel // (2**i),
  412. config.upsample_initial_channel // (2 ** (i + 1)),
  413. kernel_size=kernel_size,
  414. stride=upsample_rate,
  415. padding=(kernel_size - upsample_rate) // 2,
  416. )
  417. )
  418. self.resblocks = nn.ModuleList()
  419. for i in range(len(self.upsampler)):
  420. channels = config.upsample_initial_channel // (2 ** (i + 1))
  421. for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
  422. self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
  423. self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
  424. if config.speaker_embedding_size != 0:
  425. self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)
  426. def apply_weight_norm(self):
  427. weight_norm = nn.utils.weight_norm
  428. if hasattr(nn.utils.parametrizations, "weight_norm"):
  429. weight_norm = nn.utils.parametrizations.weight_norm
  430. for layer in self.upsampler:
  431. weight_norm(layer)
  432. for layer in self.resblocks:
  433. layer.apply_weight_norm()
  434. def remove_weight_norm(self):
  435. for layer in self.upsampler:
  436. nn.utils.remove_weight_norm(layer)
  437. for layer in self.resblocks:
  438. layer.remove_weight_norm()
  439. def forward(
  440. self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None
  441. ) -> torch.FloatTensor:
  442. r"""
  443. Converts a spectrogram into a speech waveform.
  444. Args:
  445. spectrogram (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`):
  446. Tensor containing the spectrograms.
  447. global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*):
  448. Tensor containing speaker embeddings, for multispeaker models.
  449. Returns:
  450. `torch.FloatTensor`: Tensor of shape shape `(batch_size, 1, num_frames)` containing the speech waveform.
  451. """
  452. hidden_states = self.conv_pre(spectrogram)
  453. if global_conditioning is not None:
  454. hidden_states = hidden_states + self.cond(global_conditioning)
  455. for i in range(self.num_upsamples):
  456. hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
  457. hidden_states = self.upsampler[i](hidden_states)
  458. res_state = self.resblocks[i * self.num_kernels](hidden_states)
  459. for j in range(1, self.num_kernels):
  460. res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
  461. hidden_states = res_state / self.num_kernels
  462. hidden_states = nn.functional.leaky_relu(hidden_states)
  463. hidden_states = self.conv_post(hidden_states)
  464. waveform = torch.tanh(hidden_states)
  465. return waveform
  466. class VitsResidualCouplingLayer(nn.Module):
  467. def __init__(self, config: VitsConfig):
  468. super().__init__()
  469. self.half_channels = config.flow_size // 2
  470. self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
  471. self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
  472. self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
  473. def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
  474. first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
  475. hidden_states = self.conv_pre(first_half) * padding_mask
  476. hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
  477. mean = self.conv_post(hidden_states) * padding_mask
  478. log_stddev = torch.zeros_like(mean)
  479. if not reverse:
  480. second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
  481. outputs = torch.cat([first_half, second_half], dim=1)
  482. log_determinant = torch.sum(log_stddev, [1, 2])
  483. return outputs, log_determinant
  484. else:
  485. second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
  486. outputs = torch.cat([first_half, second_half], dim=1)
  487. return outputs, None
  488. class VitsResidualCouplingBlock(nn.Module):
  489. def __init__(self, config: VitsConfig):
  490. super().__init__()
  491. self.flows = nn.ModuleList()
  492. for _ in range(config.prior_encoder_num_flows):
  493. self.flows.append(VitsResidualCouplingLayer(config))
  494. def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
  495. if not reverse:
  496. for flow in self.flows:
  497. inputs, _ = flow(inputs, padding_mask, global_conditioning)
  498. inputs = torch.flip(inputs, [1])
  499. else:
  500. for flow in reversed(self.flows):
  501. inputs = torch.flip(inputs, [1])
  502. inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
  503. return inputs
  504. class VitsDilatedDepthSeparableConv(nn.Module):
  505. def __init__(self, config: VitsConfig, dropout_rate=0.0):
  506. super().__init__()
  507. kernel_size = config.duration_predictor_kernel_size
  508. channels = config.hidden_size
  509. self.num_layers = config.depth_separable_num_layers
  510. self.dropout = nn.Dropout(dropout_rate)
  511. self.convs_dilated = nn.ModuleList()
  512. self.convs_pointwise = nn.ModuleList()
  513. self.norms_1 = nn.ModuleList()
  514. self.norms_2 = nn.ModuleList()
  515. for i in range(self.num_layers):
  516. dilation = kernel_size**i
  517. padding = (kernel_size * dilation - dilation) // 2
  518. self.convs_dilated.append(
  519. nn.Conv1d(
  520. in_channels=channels,
  521. out_channels=channels,
  522. kernel_size=kernel_size,
  523. groups=channels,
  524. dilation=dilation,
  525. padding=padding,
  526. )
  527. )
  528. self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
  529. self.norms_1.append(nn.LayerNorm(channels))
  530. self.norms_2.append(nn.LayerNorm(channels))
  531. def forward(self, inputs, padding_mask, global_conditioning=None):
  532. if global_conditioning is not None:
  533. inputs = inputs + global_conditioning
  534. for i in range(self.num_layers):
  535. hidden_states = self.convs_dilated[i](inputs * padding_mask)
  536. hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
  537. hidden_states = nn.functional.gelu(hidden_states)
  538. hidden_states = self.convs_pointwise[i](hidden_states)
  539. hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
  540. hidden_states = nn.functional.gelu(hidden_states)
  541. hidden_states = self.dropout(hidden_states)
  542. inputs = inputs + hidden_states
  543. return inputs * padding_mask
  544. class VitsConvFlow(nn.Module):
  545. def __init__(self, config: VitsConfig):
  546. super().__init__()
  547. self.filter_channels = config.hidden_size
  548. self.half_channels = config.depth_separable_channels // 2
  549. self.num_bins = config.duration_predictor_flow_bins
  550. self.tail_bound = config.duration_predictor_tail_bound
  551. self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
  552. self.conv_dds = VitsDilatedDepthSeparableConv(config)
  553. self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)
  554. def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
  555. first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
  556. hidden_states = self.conv_pre(first_half)
  557. hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
  558. hidden_states = self.conv_proj(hidden_states) * padding_mask
  559. batch_size, channels, length = first_half.shape
  560. hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)
  561. unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
  562. unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
  563. unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]
  564. second_half, log_abs_det = _unconstrained_rational_quadratic_spline(
  565. second_half,
  566. unnormalized_widths,
  567. unnormalized_heights,
  568. unnormalized_derivatives,
  569. reverse=reverse,
  570. tail_bound=self.tail_bound,
  571. )
  572. outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
  573. if not reverse:
  574. log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2])
  575. return outputs, log_determinant
  576. else:
  577. return outputs, None
  578. class VitsElementwiseAffine(nn.Module):
  579. def __init__(self, config: VitsConfig):
  580. super().__init__()
  581. self.channels = config.depth_separable_channels
  582. self.translate = nn.Parameter(torch.zeros(self.channels, 1))
  583. self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))
  584. def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
  585. if not reverse:
  586. outputs = self.translate + torch.exp(self.log_scale) * inputs
  587. outputs = outputs * padding_mask
  588. log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
  589. return outputs, log_determinant
  590. else:
  591. outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
  592. return outputs, None
class VitsStochasticDurationPredictor(nn.Module):
    """Flow-based duration predictor.

    With `reverse=False` the module takes `durations` and returns `nll + logq`, one value per
    batch element. With `reverse=True` it samples noise, runs `self.flows` in reverse and
    returns the first channel of the resulting latents (named `log_duration` in the code).
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.speaker_embedding_size
        filter_channels = config.hidden_size

        # Encoder applied to `inputs`; its output conditions every flow.
        self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        # `self.cond` only exists when a speaker embedding size is configured.
        if embed_dim != 0:
            self.cond = nn.Conv1d(embed_dim, filter_channels, 1)

        # Main flows: one element-wise affine followed by `duration_predictor_num_flows` conv flows.
        self.flows = nn.ModuleList()
        self.flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.flows.append(VitsConvFlow(config))

        # Encoder applied to `durations` (a single input channel), used only when `reverse=False`.
        self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_conv_dds = VitsDilatedDepthSeparableConv(
            config,
            dropout_rate=config.duration_predictor_dropout,
        )

        # Posterior flows, same layout as `self.flows`; used only when `reverse=False`.
        self.post_flows = nn.ModuleList()
        self.post_flows.append(VitsElementwiseAffine(config))
        for _ in range(config.duration_predictor_num_flows):
            self.post_flows.append(VitsConvFlow(config))

    def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
        """Compute the training objective (`reverse=False`) or sample log-durations (`reverse=True`).

        Args:
            inputs: Tensor fed to `conv_pre`; it is detached first.
            padding_mask: Tensor multiplied element-wise with the activations.
            global_conditioning: Optional tensor passed through `self.cond` and added to the
                encoded inputs; it is detached first.
            durations: Tensor with one channel on dim 1 (it is fed to `Conv1d(1, ...)`).
                Read only when `reverse=False`.
            reverse: Selects sampling (`True`) instead of the likelihood computation (`False`).
            noise_scale: Multiplier applied to the sampled noise when `reverse=True`.

        Returns:
            `nll + logq` when `reverse=False`, otherwise the first channel of the latents
            produced by the reversed flows.
        """
        # Detach so that no gradient flows from this module back into whatever produced `inputs`.
        inputs = torch.detach(inputs)
        inputs = self.conv_pre(inputs)

        if global_conditioning is not None:
            global_conditioning = torch.detach(global_conditioning)
            inputs = inputs + self.cond(global_conditioning)

        inputs = self.conv_dds(inputs, padding_mask)
        inputs = self.conv_proj(inputs) * padding_mask

        if not reverse:
            hidden_states = self.post_conv_pre(durations)
            hidden_states = self.post_conv_dds(hidden_states, padding_mask)
            hidden_states = self.post_conv_proj(hidden_states) * padding_mask

            # Two-channel Gaussian noise; created with default settings and then moved to the
            # device/dtype of `inputs`.
            random_posterior = (
                torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * padding_mask
            )
            log_determinant_posterior_sum = 0
            latents_posterior = random_posterior
            # The posterior flows are conditioned on the encoded inputs plus the encoded durations.
            for flow in self.post_flows:
                latents_posterior, log_determinant = flow(
                    latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
                )
                # Swap the two channels so the next flow transforms the other half.
                latents_posterior = torch.flip(latents_posterior, [1])
                log_determinant_posterior_sum += log_determinant

            first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)

            # logsigmoid(x) + logsigmoid(-x) is the log-derivative of the sigmoid applied to
            # `first_half` a few lines below.
            log_determinant_posterior_sum += torch.sum(
                (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
            )
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
                - log_determinant_posterior_sum
            )

            first_half = (durations - torch.sigmoid(first_half)) * padding_mask
            # Clamp before the log so non-positive values do not produce -inf/NaN.
            first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
            log_determinant_sum = torch.sum(-first_half, [1, 2])

            latents = torch.cat([first_half, second_half], dim=1)
            for flow in self.flows:
                latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
                latents = torch.flip(latents, [1])
                log_determinant_sum += log_determinant

            nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
            return nll + logq
        else:
            flows = list(reversed(self.flows))
            # Drops the second-to-last entry of the reversed list, i.e. `self.flows[1]` (the first
            # conv flow), while keeping `self.flows[0]` (the element-wise affine) as the last step.
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow

            latents = (
                torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
                * noise_scale
            )
            for flow in flows:
                # Mirror of the forward pass: flip first, then invert the flow.
                latents = torch.flip(latents, [1])
                latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)

            log_duration, _ = torch.split(latents, [1, 1], dim=1)
            return log_duration
  674. class VitsDurationPredictor(nn.Module):
  675. def __init__(self, config):
  676. super().__init__()
  677. kernel_size = config.duration_predictor_kernel_size
  678. filter_channels = config.duration_predictor_filter_channels
  679. self.dropout = nn.Dropout(config.duration_predictor_dropout)
  680. self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2)
  681. self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
  682. self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
  683. self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
  684. self.proj = nn.Conv1d(filter_channels, 1, 1)
  685. if config.speaker_embedding_size != 0:
  686. self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1)
  687. def forward(self, inputs, padding_mask, global_conditioning=None):
  688. inputs = torch.detach(inputs)
  689. if global_conditioning is not None:
  690. global_conditioning = torch.detach(global_conditioning)
  691. inputs = inputs + self.cond(global_conditioning)
  692. inputs = self.conv_1(inputs * padding_mask)
  693. inputs = torch.relu(inputs)
  694. inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1)
  695. inputs = self.dropout(inputs)
  696. inputs = self.conv_2(inputs * padding_mask)
  697. inputs = torch.relu(inputs)
  698. inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1)
  699. inputs = self.dropout(inputs)
  700. inputs = self.proj(inputs * padding_mask)
  701. return inputs * padding_mask
  702. class VitsAttention(nn.Module):
  703. """Multi-headed attention with relative positional representation."""
  704. def __init__(self, config: VitsConfig):
  705. super().__init__()
  706. self.embed_dim = config.hidden_size
  707. self.num_heads = config.num_attention_heads
  708. self.dropout = config.attention_dropout
  709. self.window_size = config.window_size
  710. self.head_dim = self.embed_dim // self.num_heads
  711. self.scaling = self.head_dim**-0.5
  712. if (self.head_dim * self.num_heads) != self.embed_dim:
  713. raise ValueError(
  714. f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}"
  715. f" and `num_attention_heads`: {self.num_heads})."
  716. )
  717. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
  718. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
  719. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
  720. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
  721. if self.window_size:
  722. self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
  723. self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
  724. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  725. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  726. def forward(
  727. self,
  728. hidden_states: torch.Tensor,
  729. key_value_states: Optional[torch.Tensor] = None,
  730. attention_mask: Optional[torch.Tensor] = None,
  731. layer_head_mask: Optional[torch.Tensor] = None,
  732. output_attentions: bool = False,
  733. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  734. """Input shape: Batch x Time x Channel"""
  735. # if key_value_states are provided this layer is used as a cross-attention layer
  736. # for the decoder
  737. bsz, tgt_len, _ = hidden_states.size()
  738. # get query proj
  739. query_states = self.q_proj(hidden_states) * self.scaling
  740. # self_attention
  741. key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
  742. value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
  743. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  744. query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
  745. key_states = key_states.view(*proj_shape)
  746. value_states = value_states.view(*proj_shape)
  747. src_len = key_states.size(1)
  748. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  749. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  750. raise ValueError(
  751. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  752. f" {attn_weights.size()}"
  753. )
  754. if self.window_size is not None:
  755. key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len)
  756. relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1))
  757. rel_pos_bias = self._relative_position_to_absolute_position(relative_logits)
  758. attn_weights += rel_pos_bias
  759. if attention_mask is not None:
  760. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  761. raise ValueError(
  762. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  763. )
  764. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  765. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  766. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  767. if layer_head_mask is not None:
  768. if layer_head_mask.size() != (self.num_heads,):
  769. raise ValueError(
  770. f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
  771. f" {layer_head_mask.size()}"
  772. )
  773. attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  774. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  775. if output_attentions:
  776. # this operation is a bit awkward, but it's required to
  777. # make sure that attn_weights keeps its gradient.
  778. # In order to do so, attn_weights have to be reshaped
  779. # twice and have to be reused in the following
  780. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  781. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  782. else:
  783. attn_weights_reshaped = None
  784. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  785. attn_output = torch.bmm(attn_probs, value_states)
  786. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  787. raise ValueError(
  788. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  789. f" {attn_output.size()}"
  790. )
  791. if self.window_size is not None:
  792. value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len)
  793. relative_weights = self._absolute_position_to_relative_position(attn_probs)
  794. rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings)
  795. attn_output += rel_pos_bias
  796. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  797. attn_output = attn_output.transpose(1, 2)
  798. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  799. # partitioned across GPUs when using tensor-parallelism.
  800. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  801. attn_output = self.out_proj(attn_output)
  802. return attn_output, attn_weights_reshaped
  803. def _get_relative_embeddings(self, relative_embeddings, length):
  804. pad_length = max(length - (self.window_size + 1), 0)
  805. if pad_length > 0:
  806. relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
  807. slice_start_position = max((self.window_size + 1) - length, 0)
  808. slice_end_position = slice_start_position + 2 * length - 1
  809. return relative_embeddings[:, slice_start_position:slice_end_position]
  810. def _relative_position_to_absolute_position(self, x):
  811. batch_heads, length, _ = x.size()
  812. # Concat columns of pad to shift from relative to absolute indexing.
  813. x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])
  814. # Concat extra elements so to add up to shape (len+1, 2*len-1).
  815. x_flat = x.view([batch_heads, length * 2 * length])
  816. x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])
  817. # Reshape and slice out the padded elements.
  818. x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1])
  819. x_final = x_final[:, :length, length - 1 :]
  820. return x_final
  821. def _absolute_position_to_relative_position(self, x):
  822. batch_heads, length, _ = x.size()
  823. # Pad along column
  824. x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
  825. x_flat = x.view([batch_heads, length * (2 * length - 1)])
  826. # Add 0's in the beginning that will skew the elements after reshape
  827. x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])
  828. x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:]
  829. return x_final
  830. class VitsFeedForward(nn.Module):
  831. def __init__(self, config):
  832. super().__init__()
  833. self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size)
  834. self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size)
  835. self.dropout = nn.Dropout(config.activation_dropout)
  836. if isinstance(config.hidden_act, str):
  837. self.act_fn = ACT2FN[config.hidden_act]
  838. else:
  839. self.act_fn = config.hidden_act
  840. if config.ffn_kernel_size > 1:
  841. pad_left = (config.ffn_kernel_size - 1) // 2
  842. pad_right = config.ffn_kernel_size // 2
  843. self.padding = [pad_left, pad_right, 0, 0, 0, 0]
  844. else:
  845. self.padding = None
  846. def forward(self, hidden_states, padding_mask):
  847. hidden_states = hidden_states.permute(0, 2, 1)
  848. padding_mask = padding_mask.permute(0, 2, 1)
  849. hidden_states = hidden_states * padding_mask
  850. if self.padding is not None:
  851. hidden_states = nn.functional.pad(hidden_states, self.padding)
  852. hidden_states = self.conv_1(hidden_states)
  853. hidden_states = self.act_fn(hidden_states)
  854. hidden_states = self.dropout(hidden_states)
  855. hidden_states = hidden_states * padding_mask
  856. if self.padding is not None:
  857. hidden_states = nn.functional.pad(hidden_states, self.padding)
  858. hidden_states = self.conv_2(hidden_states)
  859. hidden_states = hidden_states * padding_mask
  860. hidden_states = hidden_states.permute(0, 2, 1)
  861. return hidden_states
  862. class VitsEncoderLayer(GradientCheckpointingLayer):
  863. def __init__(self, config: VitsConfig):
  864. super().__init__()
  865. self.attention = VitsAttention(config)
  866. self.dropout = nn.Dropout(config.hidden_dropout)
  867. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  868. self.feed_forward = VitsFeedForward(config)
  869. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  870. def forward(
  871. self,
  872. hidden_states: torch.Tensor,
  873. padding_mask: torch.FloatTensor,
  874. attention_mask: Optional[torch.Tensor] = None,
  875. output_attentions: bool = False,
  876. ):
  877. residual = hidden_states
  878. hidden_states, attn_weights = self.attention(
  879. hidden_states=hidden_states,
  880. attention_mask=attention_mask,
  881. output_attentions=output_attentions,
  882. )
  883. hidden_states = self.dropout(hidden_states)
  884. hidden_states = self.layer_norm(residual + hidden_states)
  885. residual = hidden_states
  886. hidden_states = self.feed_forward(hidden_states, padding_mask)
  887. hidden_states = self.dropout(hidden_states)
  888. hidden_states = self.final_layer_norm(residual + hidden_states)
  889. outputs = (hidden_states,)
  890. if output_attentions:
  891. outputs += (attn_weights,)
  892. return outputs
  893. class VitsEncoder(nn.Module):
  894. def __init__(self, config: VitsConfig):
  895. super().__init__()
  896. self.config = config
  897. self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  898. self.gradient_checkpointing = False
  899. self.layerdrop = config.layerdrop
  900. def forward(
  901. self,
  902. hidden_states: torch.FloatTensor,
  903. padding_mask: torch.FloatTensor,
  904. attention_mask: Optional[torch.Tensor] = None,
  905. output_attentions: Optional[bool] = None,
  906. output_hidden_states: Optional[bool] = None,
  907. return_dict: Optional[bool] = None,
  908. ) -> Union[tuple, BaseModelOutput]:
  909. all_hidden_states = () if output_hidden_states else None
  910. all_self_attentions = () if output_attentions else None
  911. # expand attention_mask
  912. if attention_mask is not None:
  913. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  914. attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
  915. hidden_states = hidden_states * padding_mask
  916. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  917. for encoder_layer in self.layers:
  918. if output_hidden_states:
  919. all_hidden_states = all_hidden_states + (hidden_states,)
  920. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  921. dropout_probability = np.random.uniform(0, 1)
  922. skip_the_layer = self.training and (dropout_probability < self.layerdrop)
  923. if not skip_the_layer or synced_gpus:
  924. # under fsdp or deepspeed zero3 all gpus must run in sync
  925. layer_outputs = encoder_layer(
  926. hidden_states,
  927. attention_mask=attention_mask,
  928. padding_mask=padding_mask,
  929. output_attentions=output_attentions,
  930. )
  931. hidden_states = layer_outputs[0]
  932. if skip_the_layer:
  933. layer_outputs = (None, None)
  934. if output_attentions:
  935. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  936. hidden_states = hidden_states * padding_mask
  937. if output_hidden_states:
  938. all_hidden_states = all_hidden_states + (hidden_states,)
  939. if not return_dict:
  940. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  941. return BaseModelOutput(
  942. last_hidden_state=hidden_states,
  943. hidden_states=all_hidden_states,
  944. attentions=all_self_attentions,
  945. )
  946. class VitsTextEncoder(nn.Module):
  947. """
  948. Transformer encoder that uses relative positional representation instead of absolute positional encoding.
  949. """
  950. def __init__(self, config: VitsConfig):
  951. super().__init__()
  952. self.config = config
  953. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
  954. self.encoder = VitsEncoder(config)
  955. self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
  956. def forward(
  957. self,
  958. input_ids: torch.Tensor,
  959. padding_mask: torch.FloatTensor,
  960. attention_mask: Optional[torch.Tensor] = None,
  961. output_attentions: Optional[bool] = None,
  962. output_hidden_states: Optional[bool] = None,
  963. return_dict: Optional[bool] = True,
  964. ) -> Union[tuple[torch.Tensor], VitsTextEncoderOutput]:
  965. hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
  966. encoder_outputs = self.encoder(
  967. hidden_states=hidden_states,
  968. padding_mask=padding_mask,
  969. attention_mask=attention_mask,
  970. output_attentions=output_attentions,
  971. output_hidden_states=output_hidden_states,
  972. return_dict=return_dict,
  973. )
  974. last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state
  975. stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask
  976. prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)
  977. if not return_dict:
  978. outputs = (last_hidden_state, prior_means, prior_log_variances) + encoder_outputs[1:]
  979. return outputs
  980. return VitsTextEncoderOutput(
  981. last_hidden_state=last_hidden_state,
  982. prior_means=prior_means,
  983. prior_log_variances=prior_log_variances,
  984. hidden_states=encoder_outputs.hidden_states,
  985. attentions=encoder_outputs.attentions,
  986. )
  987. @auto_docstring
  988. class VitsPreTrainedModel(PreTrainedModel):
  989. config: VitsConfig
  990. base_model_prefix = "vits"
  991. main_input_name = "input_ids"
  992. supports_gradient_checkpointing = True
  993. def _init_weights(self, module: nn.Module):
  994. """Initialize the weights"""
  995. std = self.config.initializer_range
  996. if isinstance(module, nn.Linear):
  997. module.weight.data.normal_(mean=0.0, std=std)
  998. if module.bias is not None:
  999. module.bias.data.zero_()
  1000. elif isinstance(module, nn.LayerNorm):
  1001. module.bias.data.zero_()
  1002. module.weight.data.fill_(1.0)
  1003. elif isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
  1004. nn.init.kaiming_normal_(module.weight)
  1005. if module.bias is not None:
  1006. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  1007. nn.init.uniform_(module.bias, a=-k, b=k)
  1008. elif isinstance(module, nn.Embedding):
  1009. module.weight.data.normal_(mean=0.0, std=std)
  1010. if module.padding_idx is not None:
  1011. module.weight.data[module.padding_idx].zero_()
  1012. elif isinstance(module, VitsAttention):
  1013. if self.config.window_size:
  1014. head_dim = self.config.hidden_size // self.config.num_attention_heads
  1015. nn.init.normal_(module.emb_rel_k, std=head_dim**-0.5)
  1016. nn.init.normal_(module.emb_rel_v, std=head_dim**-0.5)
  1017. elif isinstance(module, VitsElementwiseAffine):
  1018. module.translate.data.zero_()
  1019. module.log_scale.data.zero_()
@auto_docstring(
    custom_intro="""
    The complete VITS model, for text-to-speech synthesis.
    """
)
class VitsModel(VitsPreTrainedModel):
    def __init__(self, config: VitsConfig):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        # The two predictors have different call signatures; `forward` branches on the same flag.
        if config.use_stochastic_duration_prediction:
            self.duration_predictor = VitsStochasticDurationPredictor(config)
        else:
            self.duration_predictor = VitsDurationPredictor(config)

        # The speaker embedding table only exists for multi-speaker configurations.
        if config.num_speakers > 1:
            self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)

        # This is used only for training.
        self.posterior_encoder = VitsPosteriorEncoder(config)

        # These parameters control the synthesised speech properties
        self.speaking_rate = config.speaking_rate
        self.noise_scale = config.noise_scale
        self.noise_scale_duration = config.noise_scale_duration

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        # Returns the text encoder sub-module.
        return self.text_encoder

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        speaker_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.FloatTensor] = None,
    ) -> Union[tuple[Any], VitsModelOutput]:
        r"""
        speaker_id (`int`, *optional*):
            Which speaker embedding to use. Only used for multispeaker models.
        labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
            Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
            computation.

        Example:

        ```python
        >>> from transformers import VitsTokenizer, VitsModel, set_seed
        >>> import torch

        >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
        >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")

        >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")

        >>> set_seed(555)  # make deterministic

        >>> with torch.no_grad():
        ...     outputs = model(inputs["input_ids"])
        >>> outputs.waveform.shape
        torch.Size([1, 45824])
        ```
        """
        # Fall back to the config defaults for any output option left as `None`.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            raise NotImplementedError("Training of VITS is not supported yet.")

        # Build a (batch, seq_len, 1) padding mask in the dtype of the embedding weights;
        # without an attention mask every position is treated as valid.
        mask_dtype = self.text_encoder.embed_tokens.weight.dtype
        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)

        # Speaker conditioning is only used for multi-speaker configs when a speaker is requested.
        if self.config.num_speakers > 1 and speaker_id is not None:
            if not 0 <= speaker_id < self.config.num_speakers:
                raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
            if isinstance(speaker_id, int):
                speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
            # Trailing singleton axis so the embedding can be fed to the 1x1 `Conv1d` conditioning layers.
            speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)
        else:
            speaker_embeddings = None

        text_encoder_output = self.text_encoder(
            input_ids=input_ids,
            padding_mask=input_padding_mask,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
        # Switch to channels-first for the convolutional modules below; the mask becomes (batch, 1, seq_len).
        hidden_states = hidden_states.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
        prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances

        if self.config.use_stochastic_duration_prediction:
            log_duration = self.duration_predictor(
                hidden_states,
                input_padding_mask,
                speaker_embeddings,
                reverse=True,
                noise_scale=self.noise_scale_duration,
            )
        else:
            log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)

        # A higher speaking rate shortens every duration; durations are rounded up after scaling.
        length_scale = 1.0 / self.speaking_rate
        duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
        # Total output length per batch element, at least 1.
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()

        # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

        # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        # Row i is 1 for every output index below the cumulative duration of input position i.
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        # Subtracting the previous input position's row leaves 1 only on the output indices
        # that belong to position i.
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask

        # Expand prior distribution
        prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
        prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)

        # Sample around the expanded means; `noise_scale` scales the random term.
        prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
        latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)

        spectrogram = latents * output_padding_mask
        waveform = self.decoder(spectrogram, speaker_embeddings)
        waveform = waveform.squeeze(1)
        # Scale the predicted lengths by the product of the configured upsample rates.
        sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)

        if not return_dict:
            # Entries 0-2 of the text encoder tuple were consumed above; forward the rest.
            outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
            return outputs

        return VitsModelOutput(
            waveform=waveform,
            sequence_lengths=sequence_lengths,
            spectrogram=spectrogram,
            hidden_states=text_encoder_output.hidden_states,
            attentions=text_encoder_output.attentions,
        )
# Names exported by `from <module> import *`; every other class in this file is internal.
__all__ = ["VitsModel", "VitsPreTrainedModel"]