modeling_speecht5.py 144 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218321932203221322232233224322532263227322832293230323132323233323432353236323732383239324032413242
  1. # coding=utf-8
  2. # Copyright 2023 The Fairseq Authors, Microsoft Research, and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch SpeechT5 model."""
  16. import math
  17. from typing import Optional, Union
  18. import numpy as np
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, L1Loss
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  26. from ...integrations.fsdp import is_fsdp_managed_module
  27. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import (
  30. BaseModelOutput,
  31. BaseModelOutputWithPastAndCrossAttentions,
  32. Seq2SeqLMOutput,
  33. Seq2SeqModelOutput,
  34. Seq2SeqSpectrogramOutput,
  35. )
  36. from ...modeling_utils import EmbeddingAccessMixin, PreTrainedModel
  37. from ...utils import auto_docstring, logging
  38. from ...utils.deprecation import deprecate_kwarg
  39. from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# NOTE(review): this constant is not referenced in the visible part of the file; presumably it is the index at
# which hidden states start in a model's output tuple - confirm against its usages.
_HIDDEN_STATES_START_POSITION = 1
  42. # General docstring
  43. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  44. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  45. """
  46. Shift input ids one token to the right.
  47. """
  48. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  49. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  50. shifted_input_ids[:, 0] = decoder_start_token_id
  51. if pad_token_id is None:
  52. raise ValueError("self.model.config.pad_token_id has to be defined.")
  53. # replace possible -100 values in labels by `pad_token_id`
  54. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  55. return shifted_input_ids
  56. def shift_spectrograms_right(
  57. input_values: torch.Tensor, reduction_factor: int = 1, attention_mask: Optional[torch.Tensor] = None
  58. ):
  59. """
  60. Shift input spectrograms one timestep to the right. Also applies the reduction factor to the sequence length.
  61. """
  62. # thin out frames for reduction factor
  63. if reduction_factor > 1:
  64. input_values = input_values[:, reduction_factor - 1 :: reduction_factor]
  65. if attention_mask is not None:
  66. attention_mask = attention_mask[:, reduction_factor - 1 :: reduction_factor]
  67. shifted_input_values = input_values.new_zeros(input_values.shape)
  68. shifted_input_values[:, 1:] = input_values[:, :-1].clone()
  69. # replace possible -100 values in labels by zeros
  70. shifted_input_values.masked_fill_(shifted_input_values == -100.0, 0.0)
  71. return shifted_input_values, attention_mask
  72. # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
def _compute_mask_indices(
    shape: tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                   independently generated mask spans of length `mask_length` is computed by
                   `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                   actual percentage will be smaller.
        mask_length: size of the mask
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
        min_masks: minimum number of masked spans

    Returns:
        A boolean `np.ndarray` of shape `(batch_size, sequence_length)` in which masked positions are `True`.

    Raises:
        ValueError: if `mask_length` is smaller than 1 or larger than `sequence_length`.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding; it is drawn once and shared by every row of the batch
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    # (per-row lengths come from summing the attention mask; without a mask every row uses the full length)
    input_lengths = (
        attention_mask.detach().sum(-1).tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    # span count for a full-length row; every row's index vector is padded up to this size below
    max_num_masked_span = compute_num_masked_span(sequence_length)

    # nothing to mask: return the all-False mask
    if max_num_masked_span == 0:
        return spec_aug_mask

    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask (span start positions, sampled without replacement)
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
        batch_size, max_num_masked_span * mask_length
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return spec_aug_mask
  168. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SpeechT5
  169. class SpeechT5NoLayerNormConvLayer(GradientCheckpointingLayer):
  170. def __init__(self, config, layer_id=0):
  171. super().__init__()
  172. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  173. self.out_conv_dim = config.conv_dim[layer_id]
  174. self.conv = nn.Conv1d(
  175. self.in_conv_dim,
  176. self.out_conv_dim,
  177. kernel_size=config.conv_kernel[layer_id],
  178. stride=config.conv_stride[layer_id],
  179. bias=config.conv_bias,
  180. )
  181. self.activation = ACT2FN[config.feat_extract_activation]
  182. def forward(self, hidden_states):
  183. hidden_states = self.conv(hidden_states)
  184. hidden_states = self.activation(hidden_states)
  185. return hidden_states
  186. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SpeechT5
  187. class SpeechT5LayerNormConvLayer(GradientCheckpointingLayer):
  188. def __init__(self, config, layer_id=0):
  189. super().__init__()
  190. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  191. self.out_conv_dim = config.conv_dim[layer_id]
  192. self.conv = nn.Conv1d(
  193. self.in_conv_dim,
  194. self.out_conv_dim,
  195. kernel_size=config.conv_kernel[layer_id],
  196. stride=config.conv_stride[layer_id],
  197. bias=config.conv_bias,
  198. )
  199. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  200. self.activation = ACT2FN[config.feat_extract_activation]
  201. def forward(self, hidden_states):
  202. hidden_states = self.conv(hidden_states)
  203. hidden_states = hidden_states.transpose(-2, -1)
  204. hidden_states = self.layer_norm(hidden_states)
  205. hidden_states = hidden_states.transpose(-2, -1)
  206. hidden_states = self.activation(hidden_states)
  207. return hidden_states
  208. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SpeechT5
  209. class SpeechT5GroupNormConvLayer(GradientCheckpointingLayer):
  210. def __init__(self, config, layer_id=0):
  211. super().__init__()
  212. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  213. self.out_conv_dim = config.conv_dim[layer_id]
  214. self.conv = nn.Conv1d(
  215. self.in_conv_dim,
  216. self.out_conv_dim,
  217. kernel_size=config.conv_kernel[layer_id],
  218. stride=config.conv_stride[layer_id],
  219. bias=config.conv_bias,
  220. )
  221. self.activation = ACT2FN[config.feat_extract_activation]
  222. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  223. def forward(self, hidden_states):
  224. hidden_states = self.conv(hidden_states)
  225. hidden_states = self.layer_norm(hidden_states)
  226. hidden_states = self.activation(hidden_states)
  227. return hidden_states
  228. # Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextSinusoidalPositionalEmbedding with Speech2Text->SpeechT5
  229. class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
  230. """This module produces sinusoidal positional embeddings of any length."""
  231. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
  232. super().__init__()
  233. self.offset = 2
  234. self.embedding_dim = embedding_dim
  235. self.padding_idx = padding_idx
  236. self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
  237. def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
  238. emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
  239. if hasattr(self, "weights"):
  240. # in forward put the weights on the correct dtype and device of the param
  241. emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
  242. self.register_buffer("weights", emb_weights, persistent=False)
  243. @staticmethod
  244. def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
  245. """
  246. Build sinusoidal embeddings. This matches the implementation in tensor2tensor, but differs slightly from the
  247. description in Section 3.5 of "Attention Is All You Need".
  248. """
  249. half_dim = embedding_dim // 2
  250. emb = math.log(10000) / (half_dim - 1)
  251. emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
  252. emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
  253. emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
  254. if embedding_dim % 2 == 1:
  255. # zero pad
  256. emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
  257. if padding_idx is not None:
  258. emb[padding_idx, :] = 0
  259. return emb.to(torch.get_default_dtype())
  260. @torch.no_grad()
  261. def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
  262. bsz, seq_len = input_ids.size()
  263. # Create the position ids from the input token ids. Any padded tokens remain padded.
  264. position_ids = self.create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length).to(
  265. input_ids.device
  266. )
  267. # expand embeddings if needed
  268. max_pos = self.padding_idx + 1 + seq_len
  269. if max_pos > self.weights.size(0):
  270. self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
  271. return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, -1).detach()
  272. def create_position_ids_from_input_ids(
  273. self, input_ids: torch.Tensor, padding_idx: int, past_key_values_length: Optional[int] = 0
  274. ):
  275. """
  276. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
  277. symbols are ignored. This is modified from fairseq's `utils.make_positions`.
  278. Args:
  279. x: torch.Tensor x:
  280. Returns: torch.Tensor
  281. """
  282. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  283. mask = input_ids.ne(padding_idx).int()
  284. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  285. return incremental_indices.long() + padding_idx
  286. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->SpeechT5
class SpeechT5PositionalConvEmbedding(nn.Module):
    """
    Weight-normalized grouped 1D convolution over the hidden states, followed by a `SpeechT5SamePadLayer` and an
    activation.
    """

    def __init__(self, config):
        super().__init__()
        self.conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=config.num_conv_pos_embeddings,
            padding=config.num_conv_pos_embeddings // 2,
            groups=config.num_conv_pos_embedding_groups,
        )

        # prefer the parametrization-based weight norm when this torch build provides it
        weight_norm = nn.utils.weight_norm
        if hasattr(nn.utils.parametrizations, "weight_norm"):
            weight_norm = nn.utils.parametrizations.weight_norm

        if is_deepspeed_zero3_enabled():
            import deepspeed

            # apply weight norm while the conv weight is wrapped in DeepSpeed's GatheredParameters context
            with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                self.conv = weight_norm(self.conv, name="weight", dim=2)
            # the two weight-norm tensors live under different attribute names depending on which
            # weight_norm implementation was used above
            if hasattr(self.conv, "parametrizations"):
                weight_g = self.conv.parametrizations.weight.original0
                weight_v = self.conv.parametrizations.weight.original1
            else:
                weight_g = self.conv.weight_g
                weight_v = self.conv.weight_v
            # register both tensors with DeepSpeed as external parameters of this module
            deepspeed.zero.register_external_parameter(self, weight_v)
            deepspeed.zero.register_external_parameter(self, weight_g)
        else:
            self.conv = weight_norm(self.conv, name="weight", dim=2)

        self.padding = SpeechT5SamePadLayer(config.num_conv_pos_embeddings)
        self.activation = ACT2FN[config.feat_extract_activation]

    def forward(self, hidden_states):
        # assumes hidden_states is (batch, sequence, hidden) - TODO confirm against callers.
        # Dims 1 and 2 are swapped so that Conv1d sees dim 2 of the input as its channel axis.
        hidden_states = hidden_states.transpose(1, 2)

        hidden_states = self.conv(hidden_states)
        # SpeechT5SamePadLayer drops the trailing step when the kernel size is even
        hidden_states = self.padding(hidden_states)
        hidden_states = self.activation(hidden_states)

        # restore the original axis order
        hidden_states = hidden_states.transpose(1, 2)
        return hidden_states
  323. class SpeechT5ScaledPositionalEncoding(nn.Module):
  324. """
  325. Scaled positional encoding, see §3.2 in https://huggingface.co/papers/1809.08895
  326. """
  327. def __init__(self, dropout, dim, max_len=5000):
  328. pe = torch.zeros(max_len, dim)
  329. position = torch.arange(0, max_len).unsqueeze(1)
  330. div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim))
  331. pe[:, 0::2] = torch.sin(position.float() * div_term)
  332. pe[:, 1::2] = torch.cos(position.float() * div_term)
  333. pe = pe.unsqueeze(0)
  334. super().__init__()
  335. self.register_buffer("pe", pe, persistent=False)
  336. self.dropout = nn.Dropout(p=dropout)
  337. self.dim = dim
  338. self.alpha = nn.Parameter(torch.tensor(1.0))
  339. def forward(self, emb):
  340. emb = emb + self.alpha * self.pe[:, : emb.size(1)]
  341. emb = self.dropout(emb)
  342. return emb
  343. class SpeechT5RelativePositionalEncoding(torch.nn.Module):
  344. def __init__(self, dim, max_length=1000):
  345. super().__init__()
  346. self.dim = dim
  347. self.max_length = max_length
  348. self.pe_k = torch.nn.Embedding(2 * max_length, dim)
  349. def forward(self, hidden_states):
  350. seq_len = hidden_states.shape[1]
  351. pos_seq = torch.arange(0, seq_len).to(device=hidden_states.device, dtype=torch.long)
  352. pos_seq = pos_seq[:, None] - pos_seq[None, :]
  353. pos_seq[pos_seq < -self.max_length] = -self.max_length
  354. pos_seq[pos_seq >= self.max_length] = self.max_length - 1
  355. pos_seq = pos_seq + self.max_length
  356. return self.pe_k(pos_seq)
  357. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SpeechT5
  358. class SpeechT5SamePadLayer(nn.Module):
  359. def __init__(self, num_conv_pos_embeddings):
  360. super().__init__()
  361. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  362. def forward(self, hidden_states):
  363. if self.num_pad_remove > 0:
  364. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  365. return hidden_states
  366. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SpeechT5
  367. class SpeechT5FeatureEncoder(nn.Module):
  368. """Construct the features from raw audio waveform"""
  369. def __init__(self, config):
  370. super().__init__()
  371. if config.feat_extract_norm == "group":
  372. conv_layers = [SpeechT5GroupNormConvLayer(config, layer_id=0)] + [
  373. SpeechT5NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
  374. ]
  375. elif config.feat_extract_norm == "layer":
  376. conv_layers = [
  377. SpeechT5LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
  378. ]
  379. else:
  380. raise ValueError(
  381. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  382. )
  383. self.conv_layers = nn.ModuleList(conv_layers)
  384. self.gradient_checkpointing = False
  385. self._requires_grad = True
  386. def _freeze_parameters(self):
  387. for param in self.parameters():
  388. param.requires_grad = False
  389. self._requires_grad = False
  390. def forward(self, input_values):
  391. hidden_states = input_values[:, None]
  392. # make sure hidden_states require grad for gradient_checkpointing
  393. if self._requires_grad and self.training:
  394. hidden_states.requires_grad = True
  395. for conv_layer in self.conv_layers:
  396. hidden_states = conv_layer(hidden_states)
  397. return hidden_states
  398. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->SpeechT5
  399. class SpeechT5FeatureProjection(nn.Module):
  400. def __init__(self, config):
  401. super().__init__()
  402. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  403. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  404. self.dropout = nn.Dropout(config.feat_proj_dropout)
  405. def forward(self, hidden_states):
  406. # non-projected hidden states are needed for quantization
  407. norm_hidden_states = self.layer_norm(hidden_states)
  408. hidden_states = self.projection(norm_hidden_states)
  409. hidden_states = self.dropout(hidden_states)
  410. return hidden_states, norm_hidden_states
  411. class SpeechT5SpeechEncoderPrenet(nn.Module):
  412. def __init__(self, config):
  413. super().__init__()
  414. self.config = config
  415. self.feature_encoder = SpeechT5FeatureEncoder(config)
  416. self.feature_projection = SpeechT5FeatureProjection(config)
  417. # model only needs masking vector if mask prob is > 0.0
  418. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  419. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  420. self.pos_conv_embed = SpeechT5PositionalConvEmbedding(config)
  421. self.pos_sinusoidal_embed = SpeechT5SinusoidalPositionalEmbedding(
  422. config.max_speech_positions + config.pad_token_id + 1,
  423. config.hidden_size,
  424. config.pad_token_id,
  425. )
  426. def freeze_feature_encoder(self):
  427. self.feature_encoder._freeze_parameters()
  428. def forward(
  429. self,
  430. input_values: torch.Tensor,
  431. attention_mask: Optional[torch.LongTensor] = None,
  432. mask_time_indices: Optional[torch.FloatTensor] = None,
  433. ):
  434. extract_features = self.feature_encoder(input_values)
  435. extract_features = extract_features.transpose(1, 2)
  436. if attention_mask is not None:
  437. # compute reduced attention_mask corresponding to feature vectors
  438. attention_mask = self._get_feature_vector_attention_mask(
  439. extract_features.shape[1],
  440. attention_mask,
  441. )
  442. hidden_states, extract_features = self.feature_projection(extract_features)
  443. hidden_states = self._mask_hidden_states(
  444. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  445. )
  446. positional_conv_embedding = self.pos_conv_embed(hidden_states)
  447. hidden_states = hidden_states + positional_conv_embedding
  448. if attention_mask is not None:
  449. padding_mask = attention_mask.ne(1).long()
  450. else:
  451. padding_mask = torch.zeros(hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device)
  452. positional_sinusoidal_embeddings = self.pos_sinusoidal_embed(padding_mask)
  453. hidden_states = hidden_states + positional_sinusoidal_embeddings
  454. return hidden_states, attention_mask
  455. # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feature_vector_attention_mask
  456. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  457. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  458. # on inference mode.
  459. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  460. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
  461. batch_size = attention_mask.shape[0]
  462. attention_mask = torch.zeros(
  463. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  464. )
  465. # these two operations makes sure that all values before the output lengths idxs are attended to
  466. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  467. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  468. return attention_mask
  469. # Copied from transformers.models.unispeech.modeling_unispeech.UniSpeechPreTrainedModel._get_feat_extract_output_lengths
  470. def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
  471. """
  472. Computes the output length of the convolutional layers
  473. """
  474. def _conv_out_length(input_length, kernel_size, stride):
  475. # 1D convolutional layer output length formula taken
  476. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  477. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  478. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  479. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  480. return input_lengths
  481. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
  482. def _mask_hidden_states(
  483. self,
  484. hidden_states: torch.FloatTensor,
  485. mask_time_indices: Optional[torch.FloatTensor] = None,
  486. attention_mask: Optional[torch.LongTensor] = None,
  487. ):
  488. """
  489. Masks extracted features along time axis and/or along feature axis according to
  490. [SpecAugment](https://huggingface.co/papers/1904.08779).
  491. """
  492. # `config.apply_spec_augment` can set masking to False
  493. if not getattr(self.config, "apply_spec_augment", True):
  494. return hidden_states
  495. # generate indices & apply SpecAugment along time axis
  496. batch_size, sequence_length, hidden_size = hidden_states.size()
  497. if mask_time_indices is not None:
  498. # apply SpecAugment along time axis with given mask_time_indices
  499. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  500. elif self.config.mask_time_prob > 0 and self.training:
  501. mask_time_indices = _compute_mask_indices(
  502. (batch_size, sequence_length),
  503. mask_prob=self.config.mask_time_prob,
  504. mask_length=self.config.mask_time_length,
  505. attention_mask=attention_mask,
  506. min_masks=self.config.mask_time_min_masks,
  507. )
  508. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  509. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  510. if self.config.mask_feature_prob > 0 and self.training:
  511. # generate indices & apply SpecAugment along feature axis
  512. mask_feature_indices = _compute_mask_indices(
  513. (batch_size, hidden_size),
  514. mask_prob=self.config.mask_feature_prob,
  515. mask_length=self.config.mask_feature_length,
  516. min_masks=self.config.mask_feature_min_masks,
  517. )
  518. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  519. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  520. hidden_states[mask_feature_indices] = 0
  521. return hidden_states
  522. class SpeechT5SpeechDecoderPrenet(nn.Module):
  523. def __init__(self, config):
  524. super().__init__()
  525. self.config = config
  526. self.layers = nn.ModuleList(
  527. [
  528. nn.Linear(
  529. config.num_mel_bins if i == 0 else config.speech_decoder_prenet_units,
  530. config.speech_decoder_prenet_units,
  531. )
  532. for i in range(config.speech_decoder_prenet_layers)
  533. ]
  534. )
  535. self.final_layer = nn.Linear(config.speech_decoder_prenet_units, config.hidden_size)
  536. self.encode_positions = SpeechT5ScaledPositionalEncoding(
  537. config.positional_dropout,
  538. config.hidden_size,
  539. config.max_speech_positions,
  540. )
  541. self.speaker_embeds_layer = nn.Linear(config.speaker_embedding_dim + config.hidden_size, config.hidden_size)
  542. def _consistent_dropout(self, inputs_embeds, p):
  543. mask = torch.bernoulli(inputs_embeds[0], p=p)
  544. all_masks = mask.unsqueeze(0).repeat(inputs_embeds.size(0), 1, 1)
  545. return torch.where(all_masks == 1, inputs_embeds, 0) * 1 / (1 - p)
  546. def forward(
  547. self,
  548. input_values: torch.Tensor,
  549. speaker_embeddings: Optional[torch.Tensor] = None,
  550. ):
  551. # Dropout is always applied, even when evaluating. See §2.2 in https://huggingface.co/papers/1712.05884.
  552. inputs_embeds = input_values
  553. for layer in self.layers:
  554. inputs_embeds = nn.functional.relu(layer(inputs_embeds))
  555. inputs_embeds = self._consistent_dropout(inputs_embeds, self.config.speech_decoder_prenet_dropout)
  556. inputs_embeds = self.final_layer(inputs_embeds)
  557. inputs_embeds = self.encode_positions(inputs_embeds)
  558. if speaker_embeddings is not None:
  559. speaker_embeddings = nn.functional.normalize(speaker_embeddings)
  560. speaker_embeddings = speaker_embeddings.unsqueeze(1).expand(-1, inputs_embeds.size(1), -1)
  561. inputs_embeds = torch.cat([inputs_embeds, speaker_embeddings], dim=-1)
  562. inputs_embeds = nn.functional.relu(self.speaker_embeds_layer(inputs_embeds))
  563. return inputs_embeds
  564. class SpeechT5BatchNormConvLayer(nn.Module):
  565. def __init__(self, config, layer_id=0):
  566. super().__init__()
  567. if layer_id == 0:
  568. in_conv_dim = config.num_mel_bins
  569. else:
  570. in_conv_dim = config.speech_decoder_postnet_units
  571. if layer_id == config.speech_decoder_postnet_layers - 1:
  572. out_conv_dim = config.num_mel_bins
  573. else:
  574. out_conv_dim = config.speech_decoder_postnet_units
  575. self.conv = nn.Conv1d(
  576. in_conv_dim,
  577. out_conv_dim,
  578. kernel_size=config.speech_decoder_postnet_kernel,
  579. stride=1,
  580. padding=(config.speech_decoder_postnet_kernel - 1) // 2,
  581. bias=False,
  582. )
  583. self.batch_norm = nn.BatchNorm1d(out_conv_dim)
  584. if layer_id < config.speech_decoder_postnet_layers - 1:
  585. self.activation = nn.Tanh()
  586. else:
  587. self.activation = None
  588. self.dropout = nn.Dropout(config.speech_decoder_postnet_dropout)
  589. def forward(self, hidden_states):
  590. hidden_states = self.conv(hidden_states)
  591. hidden_states = self.batch_norm(hidden_states)
  592. if self.activation is not None:
  593. hidden_states = self.activation(hidden_states)
  594. hidden_states = self.dropout(hidden_states)
  595. return hidden_states
  596. class SpeechT5SpeechDecoderPostnet(nn.Module):
  597. def __init__(self, config):
  598. super().__init__()
  599. self.config = config
  600. self.feat_out = nn.Linear(config.hidden_size, config.num_mel_bins * config.reduction_factor)
  601. self.prob_out = nn.Linear(config.hidden_size, config.reduction_factor)
  602. self.layers = nn.ModuleList(
  603. [SpeechT5BatchNormConvLayer(config, i) for i in range(config.speech_decoder_postnet_layers)]
  604. )
  605. def forward(self, hidden_states: torch.Tensor):
  606. outputs_before_postnet = self.feat_out(hidden_states).view(hidden_states.size(0), -1, self.config.num_mel_bins)
  607. outputs_after_postnet = self.postnet(outputs_before_postnet)
  608. logits = self.prob_out(hidden_states).view(hidden_states.size(0), -1)
  609. return outputs_before_postnet, outputs_after_postnet, logits
  610. def postnet(self, hidden_states: torch.Tensor):
  611. layer_output = hidden_states.transpose(1, 2)
  612. for layer in self.layers:
  613. layer_output = layer(layer_output)
  614. return hidden_states + layer_output.transpose(1, 2)
  615. class SpeechT5TextEncoderPrenet(nn.Module, EmbeddingAccessMixin):
  616. def __init__(self, config):
  617. super().__init__()
  618. self.config = config
  619. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
  620. self.encode_positions = SpeechT5ScaledPositionalEncoding(
  621. config.positional_dropout,
  622. config.hidden_size,
  623. config.max_text_positions,
  624. )
  625. def forward(self, input_ids: torch.Tensor):
  626. inputs_embeds = self.embed_tokens(input_ids)
  627. inputs_embeds = self.encode_positions(inputs_embeds)
  628. return inputs_embeds
  629. class SpeechT5TextDecoderPrenet(nn.Module, EmbeddingAccessMixin):
  630. def __init__(self, config):
  631. super().__init__()
  632. self.config = config
  633. self.dropout = nn.Dropout(config.positional_dropout)
  634. self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
  635. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
  636. self.embed_positions = SpeechT5SinusoidalPositionalEmbedding(
  637. config.max_text_positions + config.pad_token_id + 1,
  638. config.hidden_size,
  639. config.pad_token_id,
  640. )
  641. def forward(
  642. self,
  643. input_ids: torch.Tensor,
  644. attention_mask: Optional[torch.LongTensor] = None,
  645. past_key_values: Optional[Cache] = None,
  646. ):
  647. if input_ids is not None:
  648. input_shape = input_ids.size()
  649. input_ids = input_ids.view(-1, input_shape[-1])
  650. else:
  651. raise ValueError("You have to specify `decoder_input_ids`")
  652. past_key_values_length = 0
  653. if past_key_values is not None:
  654. past_key_values_length = (
  655. past_key_values[0][0].shape[-2]
  656. if not isinstance(past_key_values, Cache)
  657. else past_key_values.get_seq_length()
  658. )
  659. positions = self.embed_positions(input_ids, past_key_values_length)
  660. inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
  661. inputs_embeds += positions
  662. inputs_embeds = self.dropout(inputs_embeds)
  663. return inputs_embeds, attention_mask
  664. class SpeechT5TextDecoderPostnet(nn.Module, EmbeddingAccessMixin):
  665. def __init__(self, config):
  666. super().__init__()
  667. self.config = config
  668. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  669. def forward(self, hidden_states: torch.Tensor):
  670. return self.lm_head(hidden_states)
  671. def get_output_embeddings(self):
  672. # Post-net has no token embeddings, but its lm_head must still be
  673. # tied to the decoder weights when `tie_word_embeddings=True`.
  674. return self.lm_head
  675. def set_output_embeddings(self, new_embeddings):
  676. self.lm_head = new_embeddings
  677. class SpeechT5Attention(nn.Module):
  678. """
  679. Multi-headed attention from 'Attention Is All You Need' paper with relative position bias (see
  680. https://aclanthology.org/N18-2074.pdf)
  681. """
  682. def __init__(
  683. self,
  684. embed_dim: int,
  685. num_heads: int,
  686. dropout: Optional[float] = 0.0,
  687. is_decoder: Optional[bool] = False,
  688. bias: Optional[bool] = True,
  689. layer_idx: Optional[bool] = None,
  690. ):
  691. super().__init__()
  692. self.embed_dim = embed_dim
  693. self.num_heads = num_heads
  694. self.dropout = dropout
  695. self.head_dim = embed_dim // num_heads
  696. if (self.head_dim * num_heads) != self.embed_dim:
  697. raise ValueError(
  698. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  699. f" and `num_heads`: {num_heads})."
  700. )
  701. self.scaling = self.head_dim**-0.5
  702. self.is_decoder = is_decoder
  703. self.layer_idx = layer_idx
  704. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  705. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  706. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  707. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  708. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  709. def forward(
  710. self,
  711. hidden_states: torch.Tensor,
  712. key_value_states: Optional[torch.Tensor] = None,
  713. past_key_values: Optional[Cache] = None,
  714. attention_mask: Optional[torch.Tensor] = None,
  715. layer_head_mask: Optional[torch.Tensor] = None,
  716. position_bias: Optional[torch.Tensor] = None,
  717. output_attentions: bool = False,
  718. cache_position: Optional[torch.Tensor] = None,
  719. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
  720. """Input shape: Batch x Time x Channel"""
  721. # if key_value_states are provided this layer is used as a cross-attention layer
  722. # for the decoder
  723. is_cross_attention = key_value_states is not None
  724. bsz, tgt_len, _ = hidden_states.size()
  725. # get query proj
  726. query_states = self.q_proj(hidden_states) * self.scaling
  727. is_updated = False
  728. if past_key_values is not None:
  729. if isinstance(past_key_values, EncoderDecoderCache):
  730. is_updated = past_key_values.is_updated.get(self.layer_idx)
  731. if is_cross_attention:
  732. # after the first generated id, we can subsequently re-use all key/value_states from cache
  733. curr_past_key_value = past_key_values.cross_attention_cache
  734. else:
  735. curr_past_key_value = past_key_values.self_attention_cache
  736. else:
  737. curr_past_key_value = past_key_values
  738. current_states = key_value_states if is_cross_attention else hidden_states
  739. if is_cross_attention and past_key_values is not None and is_updated:
  740. # reuse k,v, cross_attentions
  741. key_states = curr_past_key_value.layers[self.layer_idx].keys
  742. value_states = curr_past_key_value.layers[self.layer_idx].values
  743. else:
  744. key_states = self.k_proj(current_states)
  745. value_states = self.v_proj(current_states)
  746. key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  747. value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
  748. if past_key_values is not None:
  749. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  750. cache_position = cache_position if not is_cross_attention else None
  751. key_states, value_states = curr_past_key_value.update(
  752. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  753. )
  754. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  755. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  756. past_key_values.is_updated[self.layer_idx] = True
  757. proj_shape = (bsz * self.num_heads, -1, self.head_dim)
  758. query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
  759. query_states = query_states.reshape(*proj_shape)
  760. key_states = key_states.reshape(*proj_shape)
  761. value_states = value_states.reshape(*proj_shape)
  762. src_len = key_states.size(1)
  763. attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
  764. if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  765. raise ValueError(
  766. f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
  767. f" {attn_weights.size()}"
  768. )
  769. # relative attention bias
  770. if position_bias is not None:
  771. reshape_q = query_states.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)
  772. rel_pos_bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
  773. rel_pos_bias = rel_pos_bias.transpose(0, 1).view(
  774. bsz * self.num_heads, position_bias.size(0), position_bias.size(1)
  775. )
  776. attn_weights += rel_pos_bias
  777. if attention_mask is not None:
  778. if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  779. raise ValueError(
  780. f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
  781. )
  782. attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
  783. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  784. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  785. if layer_head_mask is not None:
  786. if layer_head_mask.size() != (self.num_heads,):
  787. raise ValueError(
  788. f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
  789. f" {layer_head_mask.size()}"
  790. )
  791. attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  792. attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
  793. if output_attentions:
  794. # this operation is a bit awkward, but it's required to
  795. # make sure that attn_weights keeps its gradient.
  796. # In order to do so, attn_weights have to be reshaped
  797. # twice and have to be reused in the following
  798. attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
  799. attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
  800. else:
  801. attn_weights_reshaped = None
  802. attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  803. attn_output = torch.bmm(attn_probs, value_states)
  804. if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
  805. raise ValueError(
  806. f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
  807. f" {attn_output.size()}"
  808. )
  809. attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
  810. attn_output = attn_output.transpose(1, 2)
  811. # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
  812. # partitioned across GPUs when using tensor-parallelism.
  813. attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
  814. attn_output = self.out_proj(attn_output)
  815. return attn_output, attn_weights_reshaped
  816. class SpeechT5FeedForward(nn.Module):
  817. def __init__(self, config, intermediate_size):
  818. super().__init__()
  819. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  820. self.intermediate_dense = nn.Linear(config.hidden_size, intermediate_size)
  821. if isinstance(config.hidden_act, str):
  822. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  823. else:
  824. self.intermediate_act_fn = config.hidden_act
  825. self.output_dense = nn.Linear(intermediate_size, config.hidden_size)
  826. self.output_dropout = nn.Dropout(config.hidden_dropout)
  827. def forward(self, hidden_states):
  828. hidden_states = self.intermediate_dense(hidden_states)
  829. hidden_states = self.intermediate_act_fn(hidden_states)
  830. hidden_states = self.intermediate_dropout(hidden_states)
  831. hidden_states = self.output_dense(hidden_states)
  832. hidden_states = self.output_dropout(hidden_states)
  833. return hidden_states
  834. class SpeechT5EncoderLayer(GradientCheckpointingLayer):
  835. def __init__(self, config: SpeechT5Config):
  836. super().__init__()
  837. self.attention = SpeechT5Attention(
  838. embed_dim=config.hidden_size,
  839. num_heads=config.encoder_attention_heads,
  840. dropout=config.attention_dropout,
  841. is_decoder=False,
  842. )
  843. self.dropout = nn.Dropout(config.hidden_dropout)
  844. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  845. self.feed_forward = SpeechT5FeedForward(config, config.encoder_ffn_dim)
  846. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  847. def forward(
  848. self,
  849. hidden_states: torch.Tensor,
  850. attention_mask: Optional[torch.Tensor] = None,
  851. layer_head_mask: Optional[torch.Tensor] = None,
  852. position_bias: Optional[torch.Tensor] = None,
  853. output_attentions: bool = False,
  854. ):
  855. """
  856. Args:
  857. hidden_states (`torch.FloatTensor`):
  858. input to the layer of shape `(batch, seq_len, hidden_size)`
  859. attention_mask (`torch.FloatTensor`):
  860. attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very
  861. large negative values.
  862. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  863. `(config.encoder_attention_heads,)`.
  864. position_bias (`torch.FloatTensor`):
  865. relative position embeddings of size `(seq_len, seq_len, hidden_size // encoder_attention_heads)`
  866. output_attentions (`bool`, *optional*):
  867. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  868. returned tensors for more detail.
  869. """
  870. residual = hidden_states
  871. hidden_states, attn_weights = self.attention(
  872. hidden_states=hidden_states,
  873. attention_mask=attention_mask,
  874. layer_head_mask=layer_head_mask,
  875. position_bias=position_bias,
  876. output_attentions=output_attentions,
  877. )
  878. hidden_states = self.dropout(hidden_states)
  879. hidden_states = residual + hidden_states
  880. hidden_states = self.layer_norm(hidden_states)
  881. hidden_states = hidden_states + self.feed_forward(hidden_states)
  882. hidden_states = self.final_layer_norm(hidden_states)
  883. outputs = (hidden_states,)
  884. if output_attentions:
  885. outputs += (attn_weights,)
  886. return outputs
  887. class SpeechT5DecoderLayer(GradientCheckpointingLayer):
  888. def __init__(self, config: SpeechT5Config, layer_idx=None):
  889. super().__init__()
  890. self.self_attn = SpeechT5Attention(
  891. embed_dim=config.hidden_size,
  892. num_heads=config.decoder_attention_heads,
  893. dropout=config.attention_dropout,
  894. is_decoder=True,
  895. layer_idx=layer_idx,
  896. )
  897. self.dropout = nn.Dropout(config.hidden_dropout)
  898. self.self_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  899. self.encoder_attn = SpeechT5Attention(
  900. config.hidden_size,
  901. config.decoder_attention_heads,
  902. dropout=config.attention_dropout,
  903. is_decoder=True,
  904. layer_idx=layer_idx,
  905. )
  906. self.encoder_attn_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  907. self.feed_forward = SpeechT5FeedForward(config, config.decoder_ffn_dim)
  908. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  909. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  910. def forward(
  911. self,
  912. hidden_states: torch.Tensor,
  913. attention_mask: Optional[torch.Tensor] = None,
  914. encoder_hidden_states: Optional[torch.Tensor] = None,
  915. encoder_attention_mask: Optional[torch.Tensor] = None,
  916. layer_head_mask: Optional[torch.Tensor] = None,
  917. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  918. past_key_values: Optional[Cache] = None,
  919. output_attentions: Optional[bool] = False,
  920. use_cache: Optional[bool] = True,
  921. cache_position: Optional[torch.Tensor] = None,
  922. ):
  923. """
  924. Args:
  925. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
  926. attention_mask (`torch.FloatTensor`): attention mask of size
  927. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  928. encoder_hidden_states (`torch.FloatTensor`):
  929. cross attention input to the layer of shape `(batch, seq_len, hidden_size)`
  930. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  931. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  932. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  933. `(encoder_attention_heads,)`.
  934. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  935. size `(decoder_attention_heads,)`.
  936. past_key_values (`Cache`): cached past key and value projection states
  937. output_attentions (`bool`, *optional*):
  938. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  939. returned tensors for more detail.
  940. """
  941. residual = hidden_states
  942. # Self Attention
  943. hidden_states, self_attn_weights = self.self_attn(
  944. hidden_states=hidden_states,
  945. past_key_values=past_key_values,
  946. attention_mask=attention_mask,
  947. layer_head_mask=layer_head_mask,
  948. output_attentions=output_attentions,
  949. cache_position=cache_position,
  950. )
  951. hidden_states = self.dropout(hidden_states)
  952. hidden_states = residual + hidden_states
  953. hidden_states = self.self_attn_layer_norm(hidden_states)
  954. # Cross-Attention Block
  955. cross_attn_weights = None
  956. if encoder_hidden_states is not None:
  957. residual = hidden_states
  958. hidden_states, cross_attn_weights = self.encoder_attn(
  959. hidden_states=hidden_states,
  960. key_value_states=encoder_hidden_states,
  961. attention_mask=encoder_attention_mask,
  962. layer_head_mask=cross_attn_layer_head_mask,
  963. past_key_values=past_key_values,
  964. output_attentions=output_attentions,
  965. cache_position=cache_position,
  966. )
  967. hidden_states = self.dropout(hidden_states)
  968. hidden_states = residual + hidden_states
  969. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  970. # Fully Connected
  971. hidden_states = hidden_states + self.feed_forward(hidden_states)
  972. hidden_states = self.final_layer_norm(hidden_states)
  973. outputs = (hidden_states,)
  974. if output_attentions:
  975. outputs += (self_attn_weights, cross_attn_weights)
  976. return outputs
  977. @auto_docstring
  978. class SpeechT5PreTrainedModel(PreTrainedModel):
  979. config: SpeechT5Config
  980. base_model_prefix = "speecht5"
  981. main_input_name = "input_values"
  982. supports_gradient_checkpointing = True
  983. def _init_weights(self, module: nn.Module):
  984. """Initialize the weights"""
  985. std = self.config.initializer_range
  986. if isinstance(module, SpeechT5PositionalConvEmbedding):
  987. nn.init.normal_(
  988. module.conv.weight,
  989. mean=0,
  990. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  991. )
  992. nn.init.constant_(module.conv.bias, 0)
  993. elif isinstance(module, SpeechT5ScaledPositionalEncoding):
  994. module.alpha.data.fill_(1.0)
  995. elif isinstance(module, SpeechT5FeatureProjection):
  996. k = math.sqrt(1 / module.projection.in_features)
  997. nn.init.uniform_(module.projection.weight, a=-k, b=k)
  998. nn.init.uniform_(module.projection.bias, a=-k, b=k)
  999. elif isinstance(module, nn.Linear):
  1000. module.weight.data.normal_(mean=0.0, std=std)
  1001. if module.bias is not None:
  1002. module.bias.data.zero_()
  1003. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
  1004. module.bias.data.zero_()
  1005. module.weight.data.fill_(1.0)
  1006. elif isinstance(module, nn.Conv1d):
  1007. nn.init.kaiming_normal_(module.weight)
  1008. if module.bias is not None:
  1009. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  1010. nn.init.uniform_(module.bias, a=-k, b=k)
  1011. elif isinstance(module, nn.Embedding):
  1012. module.weight.data.normal_(mean=0.0, std=std)
  1013. if module.padding_idx is not None:
  1014. module.weight.data[module.padding_idx].zero_()
  1015. if hasattr(module, "masked_spec_embed"):
  1016. nn.init.uniform_(module.masked_spec_embed)
  1017. class SpeechT5Encoder(SpeechT5PreTrainedModel):
  1018. """
  1019. Transformer encoder consisting of *config.encoder_layers* layers. Each layer is a [`SpeechT5EncoderLayer`].
  1020. """
    def __init__(self, config: SpeechT5Config):
        # Build the encoder stack from `config`; submodules are registered in the order written here.
        super().__init__(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        # LayerDrop rate taken from the config.
        self.layerdrop = config.encoder_layerdrop

        # `config.encoder_layers` encoder layers.
        self.layers = nn.ModuleList([SpeechT5EncoderLayer(config) for _ in range(config.encoder_layers)])

        # Relative position embeddings sized to a single attention head (hidden_size // heads).
        self.embed_positions = SpeechT5RelativePositionalEncoding(
            config.hidden_size // config.encoder_attention_heads, config.encoder_max_relative_position
        )

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
  1033. def forward(
  1034. self,
  1035. hidden_states: torch.FloatTensor,
  1036. attention_mask: Optional[torch.Tensor] = None,
  1037. head_mask: Optional[torch.Tensor] = None,
  1038. output_attentions: Optional[bool] = None,
  1039. output_hidden_states: Optional[bool] = None,
  1040. return_dict: Optional[bool] = None,
  1041. ) -> Union[tuple, BaseModelOutput]:
  1042. """
  1043. Args:
  1044. hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
  1045. Features extracted from the speech or text input by the encoder prenet.
  1046. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1047. Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
  1048. `[0, 1]`:
  1049. - 1 for tokens that are **not masked**,
  1050. - 0 for tokens that are **masked**.
  1051. [What are attention masks?](../glossary#attention-mask)
  1052. output_attentions (`bool`, *optional*):
  1053. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1054. returned tensors for more detail.
  1055. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  1056. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  1057. - 1 indicates the head is **not masked**,
  1058. - 0 indicates the head is **masked**.
  1059. output_hidden_states (`bool`, *optional*):
  1060. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1061. for more detail.
  1062. return_dict (`bool`, *optional*):
  1063. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1064. """
  1065. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1066. output_hidden_states = (
  1067. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1068. )
  1069. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1070. # expand attention_mask
  1071. if attention_mask is not None:
  1072. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  1073. attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
  1074. hidden_states = self.layer_norm(hidden_states)
  1075. hidden_states = self.dropout(hidden_states)
  1076. position_bias = self.embed_positions(hidden_states)
  1077. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  1078. all_hidden_states = () if output_hidden_states else None
  1079. all_self_attentions = () if output_attentions else None
  1080. # check if head_mask has a correct number of layers specified if desired
  1081. if head_mask is not None:
  1082. if head_mask.size()[0] != len(self.layers):
  1083. raise ValueError(
  1084. f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
  1085. f" {head_mask.size()[0]}."
  1086. )
  1087. for idx, encoder_layer in enumerate(self.layers):
  1088. if output_hidden_states:
  1089. all_hidden_states = all_hidden_states + (hidden_states,)
  1090. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  1091. skip_the_layer = False
  1092. if self.training:
  1093. dropout_probability = torch.rand([])
  1094. skip_the_layer = dropout_probability < self.layerdrop
  1095. if not skip_the_layer or synced_gpus:
  1096. # under fsdp or deepspeed zero3 all gpus must run in sync
  1097. layer_outputs = encoder_layer(
  1098. hidden_states,
  1099. attention_mask=attention_mask,
  1100. position_bias=position_bias,
  1101. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  1102. output_attentions=output_attentions,
  1103. )
  1104. hidden_states = layer_outputs[0]
  1105. if skip_the_layer:
  1106. layer_outputs = (None, None)
  1107. if output_attentions:
  1108. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  1109. if output_hidden_states:
  1110. all_hidden_states = all_hidden_states + (hidden_states,)
  1111. if not return_dict:
  1112. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  1113. return BaseModelOutput(
  1114. last_hidden_state=hidden_states,
  1115. hidden_states=all_hidden_states,
  1116. attentions=all_self_attentions,
  1117. )
  1118. class SpeechT5EncoderWithSpeechPrenet(SpeechT5PreTrainedModel):
  1119. """
  1120. Wrapper around SpeechT5Encoder that applies SpeechT5SpeechEncoderPrenet to convert the audio waveform data to
  1121. hidden features.
  1122. """
  1123. def __init__(self, config: SpeechT5Config):
  1124. super().__init__(config)
  1125. self.prenet = SpeechT5SpeechEncoderPrenet(config)
  1126. self.wrapped_encoder = SpeechT5Encoder(config)
  1127. # Initialize weights and apply final processing
  1128. self.post_init()
  1129. def forward(
  1130. self,
  1131. input_values: torch.FloatTensor,
  1132. attention_mask: Optional[torch.Tensor] = None,
  1133. head_mask: Optional[torch.Tensor] = None,
  1134. output_attentions: Optional[bool] = None,
  1135. output_hidden_states: Optional[bool] = None,
  1136. return_dict: Optional[bool] = None,
  1137. ) -> Union[tuple, BaseModelOutput]:
  1138. hidden_states, attention_mask = self.prenet(input_values, attention_mask)
  1139. outputs = self.wrapped_encoder(
  1140. hidden_states=hidden_states,
  1141. attention_mask=attention_mask,
  1142. head_mask=head_mask,
  1143. output_attentions=output_attentions,
  1144. output_hidden_states=output_hidden_states,
  1145. return_dict=return_dict,
  1146. )
  1147. return outputs
  1148. class SpeechT5EncoderWithTextPrenet(SpeechT5PreTrainedModel):
  1149. """
  1150. Wrapper around SpeechT5Encoder that applies SpeechT5TextEncoderPrenet to convert the input_ids to hidden features.
  1151. """
  1152. def __init__(self, config: SpeechT5Config):
  1153. super().__init__(config)
  1154. self.prenet = SpeechT5TextEncoderPrenet(config)
  1155. self.wrapped_encoder = SpeechT5Encoder(config)
  1156. # Initialize weights and apply final processing
  1157. self.post_init()
  1158. def get_input_embeddings(self):
  1159. return self.prenet.get_input_embeddings()
  1160. def set_input_embeddings(self, value):
  1161. self.prenet.set_input_embeddings(value)
  1162. def forward(
  1163. self,
  1164. input_values: torch.FloatTensor,
  1165. attention_mask: Optional[torch.Tensor] = None,
  1166. head_mask: Optional[torch.Tensor] = None,
  1167. output_attentions: Optional[bool] = None,
  1168. output_hidden_states: Optional[bool] = None,
  1169. return_dict: Optional[bool] = None,
  1170. ) -> Union[tuple, BaseModelOutput]:
  1171. hidden_states = self.prenet(input_values)
  1172. outputs = self.wrapped_encoder(
  1173. hidden_states=hidden_states,
  1174. attention_mask=attention_mask,
  1175. head_mask=head_mask,
  1176. output_attentions=output_attentions,
  1177. output_hidden_states=output_hidden_states,
  1178. return_dict=return_dict,
  1179. )
  1180. return outputs
  1181. class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel):
  1182. """
  1183. This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with
  1184. [`SpeechT5Model`].
  1185. """
  1186. def __init__(self, config: SpeechT5Config):
  1187. super().__init__(config)
  1188. self.wrapped_encoder = SpeechT5Encoder(config)
  1189. # Initialize weights and apply final processing
  1190. self.post_init()
  1191. def forward(
  1192. self,
  1193. input_values: torch.FloatTensor,
  1194. attention_mask: Optional[torch.Tensor] = None,
  1195. head_mask: Optional[torch.Tensor] = None,
  1196. output_attentions: Optional[bool] = None,
  1197. output_hidden_states: Optional[bool] = None,
  1198. return_dict: Optional[bool] = None,
  1199. ) -> Union[tuple, BaseModelOutput]:
  1200. return self.wrapped_encoder(
  1201. hidden_states=input_values,
  1202. attention_mask=attention_mask,
  1203. head_mask=head_mask,
  1204. output_attentions=output_attentions,
  1205. output_hidden_states=output_hidden_states,
  1206. return_dict=return_dict,
  1207. )
  1208. class SpeechT5Decoder(SpeechT5PreTrainedModel):
  1209. """
  1210. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`SpeechT5DecoderLayer`]
  1211. """
  1212. def __init__(self, config: SpeechT5Config):
  1213. super().__init__(config)
  1214. self.layerdrop = config.decoder_layerdrop
  1215. self.layers = nn.ModuleList([SpeechT5DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
  1216. self.gradient_checkpointing = False
  1217. # Initialize weights and apply final processing
  1218. self.post_init()
  1219. def forward(
  1220. self,
  1221. hidden_states: Optional[torch.FloatTensor] = None,
  1222. attention_mask: Optional[torch.LongTensor] = None,
  1223. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  1224. encoder_attention_mask: Optional[torch.LongTensor] = None,
  1225. head_mask: Optional[torch.Tensor] = None,
  1226. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1227. past_key_values: Optional[Cache] = None,
  1228. use_cache: Optional[bool] = None,
  1229. output_attentions: Optional[bool] = None,
  1230. output_hidden_states: Optional[bool] = None,
  1231. return_dict: Optional[bool] = None,
  1232. cache_position: Optional[torch.Tensor] = None,
  1233. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  1234. r"""
  1235. Args:
  1236. hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, feature_size)`):
  1237. Features extracted from the speech or text input by the decoder prenet.
  1238. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1239. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1240. - 1 for tokens that are **not masked**,
  1241. - 0 for tokens that are **masked**.
  1242. [What are attention masks?](../glossary#attention-mask)
  1243. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  1244. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  1245. of the decoder.
  1246. encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
  1247. Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
  1248. selected in `[0, 1]`:
  1249. - 1 for tokens that are **not masked**,
  1250. - 0 for tokens that are **masked**.
  1251. [What are attention masks?](../glossary#attention-mask)
  1252. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1253. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  1254. - 1 indicates the head is **not masked**,
  1255. - 0 indicates the head is **masked**.
  1256. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1257. Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
  1258. cross-attention on hidden heads. Mask values selected in `[0, 1]`:
  1259. - 1 indicates the head is **not masked**,
  1260. - 0 indicates the head is **masked**.
  1261. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  1262. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  1263. Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
  1264. cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
  1265. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  1266. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  1267. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1268. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1269. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  1270. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  1271. than the model's internal embedding lookup matrix.
  1272. output_attentions (`bool`, *optional*):
  1273. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  1274. returned tensors for more detail.
  1275. output_hidden_states (`bool`, *optional*):
  1276. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  1277. for more detail.
  1278. return_dict (`bool`, *optional*):
  1279. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  1280. """
  1281. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1282. output_hidden_states = (
  1283. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1284. )
  1285. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1286. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1287. input_shape = hidden_states.size()[:-1]
  1288. if self.gradient_checkpointing and self.training:
  1289. if use_cache:
  1290. logger.warning_once(
  1291. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  1292. )
  1293. use_cache = False
  1294. if use_cache and past_key_values is None:
  1295. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  1296. if use_cache and isinstance(past_key_values, tuple):
  1297. logger.warning_once(
  1298. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  1299. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  1300. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  1301. )
  1302. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  1303. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  1304. attention_mask = _prepare_4d_causal_attention_mask(
  1305. attention_mask, input_shape, hidden_states, past_key_values_length
  1306. )
  1307. # expand encoder attention mask
  1308. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  1309. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  1310. encoder_attention_mask = _prepare_4d_attention_mask(
  1311. encoder_attention_mask, hidden_states.dtype, tgt_len=input_shape[-1]
  1312. )
  1313. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  1314. # decoder layers
  1315. all_hidden_states = () if output_hidden_states else None
  1316. all_self_attentions = () if output_attentions else None
  1317. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  1318. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  1319. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  1320. if attn_mask is not None:
  1321. if attn_mask.size()[0] != (len(self.layers)):
  1322. raise ValueError(
  1323. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  1324. f" {head_mask.size()[0]}."
  1325. )
  1326. for idx, decoder_layer in enumerate(self.layers):
  1327. if output_hidden_states:
  1328. all_hidden_states = all_hidden_states + (hidden_states,)
  1329. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  1330. skip_the_layer = False
  1331. if self.training:
  1332. dropout_probability = torch.rand([])
  1333. skip_the_layer = dropout_probability < self.layerdrop
  1334. if skip_the_layer and not synced_gpus:
  1335. continue
  1336. layer_outputs = decoder_layer(
  1337. hidden_states,
  1338. attention_mask,
  1339. encoder_hidden_states, # as a positional argument for gradient checkpointing
  1340. encoder_attention_mask=encoder_attention_mask,
  1341. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  1342. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  1343. past_key_values=past_key_values,
  1344. output_attentions=output_attentions,
  1345. use_cache=use_cache,
  1346. cache_position=cache_position,
  1347. )
  1348. hidden_states = layer_outputs[0]
  1349. if output_attentions:
  1350. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  1351. if encoder_hidden_states is not None:
  1352. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  1353. if output_hidden_states:
  1354. all_hidden_states = all_hidden_states + (hidden_states,)
  1355. if not return_dict:
  1356. return tuple(
  1357. v
  1358. for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
  1359. if v is not None
  1360. )
  1361. return BaseModelOutputWithPastAndCrossAttentions(
  1362. last_hidden_state=hidden_states,
  1363. past_key_values=past_key_values,
  1364. hidden_states=all_hidden_states,
  1365. attentions=all_self_attentions,
  1366. cross_attentions=all_cross_attentions,
  1367. )
  1368. class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel):
  1369. """
  1370. Wrapper around SpeechT5Decoder that applies SpeechT5SpeechDecoderPrenet to convert log-mel filterbanks to hidden
  1371. features.
  1372. """
  1373. def __init__(self, config: SpeechT5Config):
  1374. super().__init__(config)
  1375. self.prenet = SpeechT5SpeechDecoderPrenet(config)
  1376. self.wrapped_decoder = SpeechT5Decoder(config)
  1377. # Initialize weights and apply final processing
  1378. self.post_init()
  1379. def forward(
  1380. self,
  1381. input_values: Optional[torch.FloatTensor] = None,
  1382. attention_mask: Optional[torch.LongTensor] = None,
  1383. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  1384. encoder_attention_mask: Optional[torch.LongTensor] = None,
  1385. speaker_embeddings: Optional[torch.Tensor] = None,
  1386. head_mask: Optional[torch.Tensor] = None,
  1387. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1388. past_key_values: Optional[Cache] = None,
  1389. use_cache: Optional[bool] = None,
  1390. output_attentions: Optional[bool] = None,
  1391. output_hidden_states: Optional[bool] = None,
  1392. return_dict: Optional[bool] = None,
  1393. cache_position: Optional[torch.Tensor] = None,
  1394. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  1395. decoder_hidden_states = self.prenet(input_values, speaker_embeddings)
  1396. outputs = self.wrapped_decoder(
  1397. hidden_states=decoder_hidden_states,
  1398. attention_mask=attention_mask,
  1399. encoder_hidden_states=encoder_hidden_states,
  1400. encoder_attention_mask=encoder_attention_mask,
  1401. head_mask=head_mask,
  1402. cross_attn_head_mask=cross_attn_head_mask,
  1403. past_key_values=past_key_values,
  1404. use_cache=use_cache,
  1405. output_attentions=output_attentions,
  1406. output_hidden_states=output_hidden_states,
  1407. return_dict=return_dict,
  1408. cache_position=cache_position,
  1409. )
  1410. return outputs
  1411. class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel):
  1412. """
  1413. Wrapper around SpeechT5Decoder that applies SpeechT5TextDecoderPrenet to convert input tokens to hidden features.
  1414. """
  1415. def __init__(self, config: SpeechT5Config):
  1416. super().__init__(config)
  1417. self.prenet = SpeechT5TextDecoderPrenet(config)
  1418. self.wrapped_decoder = SpeechT5Decoder(config)
  1419. # Initialize weights and apply final processing
  1420. self.post_init()
  1421. def get_input_embeddings(self):
  1422. return self.prenet.get_input_embeddings()
  1423. def set_input_embeddings(self, value):
  1424. self.prenet.set_input_embeddings(value)
  1425. def forward(
  1426. self,
  1427. input_values: Optional[torch.FloatTensor] = None,
  1428. attention_mask: Optional[torch.LongTensor] = None,
  1429. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  1430. encoder_attention_mask: Optional[torch.LongTensor] = None,
  1431. head_mask: Optional[torch.Tensor] = None,
  1432. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1433. past_key_values: Optional[Cache] = None,
  1434. use_cache: Optional[bool] = None,
  1435. output_attentions: Optional[bool] = None,
  1436. output_hidden_states: Optional[bool] = None,
  1437. return_dict: Optional[bool] = None,
  1438. cache_position: Optional[torch.Tensor] = None,
  1439. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  1440. decoder_hidden_states, attention_mask = self.prenet(input_values, attention_mask, past_key_values)
  1441. outputs = self.wrapped_decoder(
  1442. hidden_states=decoder_hidden_states,
  1443. attention_mask=attention_mask,
  1444. encoder_hidden_states=encoder_hidden_states,
  1445. encoder_attention_mask=encoder_attention_mask,
  1446. head_mask=head_mask,
  1447. cross_attn_head_mask=cross_attn_head_mask,
  1448. past_key_values=past_key_values,
  1449. use_cache=use_cache,
  1450. output_attentions=output_attentions,
  1451. output_hidden_states=output_hidden_states,
  1452. return_dict=return_dict,
  1453. cache_position=cache_position,
  1454. )
  1455. return outputs
  1456. class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel):
  1457. """
  1458. This wrapper class is a helper class to correctly load pretrained checkpoints when used in combination with
  1459. [`SpeechT5Model`].
  1460. """
  1461. def __init__(self, config: SpeechT5Config):
  1462. super().__init__(config)
  1463. self.wrapped_decoder = SpeechT5Decoder(config)
  1464. # Initialize weights and apply final processing
  1465. self.post_init()
  1466. def forward(
  1467. self,
  1468. input_values: Optional[torch.FloatTensor] = None,
  1469. attention_mask: Optional[torch.LongTensor] = None,
  1470. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  1471. encoder_attention_mask: Optional[torch.LongTensor] = None,
  1472. head_mask: Optional[torch.Tensor] = None,
  1473. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1474. past_key_values: Optional[Cache] = None,
  1475. use_cache: Optional[bool] = None,
  1476. output_attentions: Optional[bool] = None,
  1477. output_hidden_states: Optional[bool] = None,
  1478. return_dict: Optional[bool] = None,
  1479. cache_position: Optional[torch.Tensor] = None,
  1480. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  1481. outputs = self.wrapped_decoder(
  1482. hidden_states=input_values,
  1483. attention_mask=attention_mask,
  1484. encoder_hidden_states=encoder_hidden_states,
  1485. encoder_attention_mask=encoder_attention_mask,
  1486. head_mask=head_mask,
  1487. cross_attn_head_mask=cross_attn_head_mask,
  1488. past_key_values=past_key_values,
  1489. use_cache=use_cache,
  1490. output_attentions=output_attentions,
  1491. output_hidden_states=output_hidden_states,
  1492. return_dict=return_dict,
  1493. cache_position=cache_position,
  1494. )
  1495. return outputs
  1496. class SpeechT5GuidedMultiheadAttentionLoss(nn.Module):
  1497. """
  1498. Guided attention loss from the paper [Efficiently Trainable Text-to-Speech System Based on Deep Convolutional
  1499. Networks with Guided Attention](https://huggingface.co/papers/1710.08969), adapted for multi-head attention.
  1500. """
  1501. def __init__(self, config: SpeechT5Config):
  1502. super().__init__()
  1503. self.sigma = config.guided_attention_loss_sigma
  1504. self.scale = config.guided_attention_loss_scale
  1505. def forward(
  1506. self, attentions: torch.FloatTensor, input_masks: torch.BoolTensor, output_masks: torch.BoolTensor
  1507. ) -> torch.Tensor:
  1508. """
  1509. Compute the attention loss.
  1510. Args:
  1511. attentions (`torch.FloatTensor` of shape `(batch_size, layers * heads, output_sequence_length, input_sequence_length)`):
  1512. Batch of multi-head attention weights
  1513. input_masks (`torch.BoolTensor` of shape `(batch_size, input_sequence_length)`):
  1514. Input attention mask as booleans.
  1515. output_masks (`torch.BoolTensor` of shape `(batch_size, output_sequence_length)`):
  1516. Target attention mask as booleans.
  1517. Returns:
  1518. `torch.Tensor` with the loss value
  1519. """
  1520. guided_attn_masks = self._make_guided_attention_masks(input_masks, output_masks, attentions.device)
  1521. masks = output_masks.unsqueeze(-1) & input_masks.unsqueeze(-2)
  1522. masks = masks.to(attentions.device).unsqueeze(1)
  1523. losses = guided_attn_masks * attentions
  1524. loss = torch.mean(losses.masked_select(masks))
  1525. return self.scale * loss
  1526. def _make_guided_attention_masks(self, input_masks, output_masks, device):
  1527. input_lengths = input_masks.sum(-1)
  1528. output_lengths = output_masks.sum(-1)
  1529. guided_attn_masks = torch.zeros((len(input_masks), output_masks.shape[1], input_masks.shape[1]), device=device)
  1530. for idx, (ilen, olen) in enumerate(zip(input_lengths, output_lengths)):
  1531. guided_attn_masks[idx, :olen, :ilen] = self._make_guided_attention_mask(ilen, olen, self.sigma, device)
  1532. return guided_attn_masks.unsqueeze(1)
  1533. @staticmethod
  1534. def _make_guided_attention_mask(input_length, output_length, sigma, device):
  1535. grid_y, grid_x = torch.meshgrid(
  1536. torch.arange(input_length, device=device),
  1537. torch.arange(output_length, device=device),
  1538. indexing="xy",
  1539. )
  1540. grid_x = grid_x.float() / output_length
  1541. grid_y = grid_y.float() / input_length
  1542. return 1.0 - torch.exp(-((grid_y - grid_x) ** 2) / (2 * (sigma**2)))
  1543. class SpeechT5SpectrogramLoss(nn.Module):
  1544. """
  1545. Loss computation used by SpeechT5ForTextToSpeech.
  1546. """
  1547. def __init__(self, config: SpeechT5Config):
  1548. super().__init__()
  1549. self.use_guided_attention_loss = config.use_guided_attention_loss
  1550. self.guided_attention_loss_num_heads = config.guided_attention_loss_num_heads
  1551. self.reduction_factor = config.reduction_factor
  1552. self.l1_criterion = L1Loss()
  1553. self.bce_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(5.0))
  1554. if self.use_guided_attention_loss:
  1555. self.attn_criterion = SpeechT5GuidedMultiheadAttentionLoss(config)
  1556. def forward(
  1557. self,
  1558. attention_mask: torch.LongTensor,
  1559. outputs_before_postnet: torch.FloatTensor,
  1560. outputs_after_postnet: torch.FloatTensor,
  1561. logits: torch.FloatTensor,
  1562. labels: torch.FloatTensor,
  1563. cross_attentions: Optional[torch.FloatTensor] = None,
  1564. ) -> torch.Tensor:
  1565. padding_mask = labels != -100.0
  1566. # mask out the padded portions
  1567. labels = labels.masked_select(padding_mask)
  1568. outputs_before_postnet = outputs_before_postnet.masked_select(padding_mask)
  1569. outputs_after_postnet = outputs_after_postnet.masked_select(padding_mask)
  1570. # spectrogram loss
  1571. l1_loss = self.l1_criterion(outputs_after_postnet, labels) + self.l1_criterion(outputs_before_postnet, labels)
  1572. # construct stop labels from the padding mask
  1573. masks = padding_mask[:, :, 0]
  1574. stop_labels = torch.cat([~masks * 1.0, torch.ones(masks.size(0), 1).to(masks.device)], dim=1)
  1575. stop_labels = stop_labels[:, 1:].masked_select(masks)
  1576. logits = logits.masked_select(masks)
  1577. # stop token loss
  1578. bce_loss = self.bce_criterion(logits, stop_labels)
  1579. # combined loss
  1580. loss = l1_loss + bce_loss
  1581. # guided attention loss
  1582. if self.use_guided_attention_loss:
  1583. attn = torch.cat([x[:, : self.guided_attention_loss_num_heads] for x in cross_attentions], dim=1)
  1584. input_masks = attention_mask == 1
  1585. output_masks = padding_mask[:, :, 0]
  1586. if self.reduction_factor > 1:
  1587. output_masks = output_masks[:, self.reduction_factor - 1 :: self.reduction_factor]
  1588. attn_loss = self.attn_criterion(attn, input_masks, output_masks)
  1589. loss += attn_loss
  1590. return loss
  1591. @auto_docstring(
  1592. custom_intro="""
  1593. The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets.
  1594. """
  1595. )
  1596. class SpeechT5Model(SpeechT5PreTrainedModel):
  1597. def __init__(
  1598. self,
  1599. config: SpeechT5Config,
  1600. encoder: Optional[nn.Module] = None,
  1601. decoder: Optional[nn.Module] = None,
  1602. ):
  1603. r"""
  1604. encoder (`PreTrainedModel`, *optional*):
  1605. The encoder model to use.
  1606. decoder (`PreTrainedModel`, *optional*):
  1607. The decoder model to use.
  1608. """
  1609. super().__init__(config)
  1610. self.config = config
  1611. self.encoder = SpeechT5EncoderWithoutPrenet(config) if encoder is None else encoder
  1612. self.decoder = SpeechT5DecoderWithoutPrenet(config) if decoder is None else decoder
  1613. # Initialize weights and apply final processing
  1614. self.post_init()
  1615. def get_input_embeddings(self):
  1616. if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet):
  1617. return self.encoder.get_input_embeddings()
  1618. if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet):
  1619. return self.decoder.get_input_embeddings()
  1620. raise NotImplementedError
  1621. def set_input_embeddings(self, value):
  1622. if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet):
  1623. self.encoder.set_input_embeddings(value)
  1624. if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet):
  1625. self.decoder.set_input_embeddings(value)
  1626. def get_encoder(self):
  1627. return self.encoder
  1628. def freeze_feature_encoder(self):
  1629. """
  1630. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1631. not be updated during training.
  1632. """
  1633. if isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet):
  1634. self.encoder.prenet.freeze_feature_encoder()
  @auto_docstring
  def forward(
      self,
      input_values: Optional[torch.Tensor] = None,
      attention_mask: Optional[torch.LongTensor] = None,
      decoder_input_values: Optional[torch.Tensor] = None,
      decoder_attention_mask: Optional[torch.LongTensor] = None,
      head_mask: Optional[torch.FloatTensor] = None,
      decoder_head_mask: Optional[torch.FloatTensor] = None,
      cross_attn_head_mask: Optional[torch.Tensor] = None,
      encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
      past_key_values: Optional[Cache] = None,
      use_cache: Optional[bool] = None,
      speaker_embeddings: Optional[torch.FloatTensor] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      cache_position: Optional[torch.Tensor] = None,
  ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
      r"""
      input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
          Depending on which encoder is being used, the `input_values` are either: float values of the input raw
          speech waveform, or indices of input sequence tokens in the vocabulary, or hidden states.
      decoder_input_values (`torch.Tensor` of shape `(batch_size, target_sequence_length)`, *optional*):
          Depending on which decoder is being used, the `decoder_input_values` are either: float values of log-mel
          filterbank features extracted from the raw speech waveform, or indices of decoder input sequence tokens in
          the vocabulary, or hidden states.
      decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
          Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
          also be used by default.

          If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
          and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
          information on the default strategy.
      cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
          Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

          - 1 indicates the head is **not masked**,
          - 0 indicates the head is **masked**.
      speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
          Tensor containing the speaker embeddings.
      """
      # Fall back to the configuration for every flag the caller left unset.
      config = self.config
      if output_attentions is None:
          output_attentions = config.output_attentions
      if output_hidden_states is None:
          output_hidden_states = config.output_hidden_states
      if use_cache is None:
          use_cache = config.use_cache
      if return_dict is None:
          return_dict = config.use_return_dict

      if encoder_outputs is None:
          # No precomputed encoder states (training, or the first prediction pass): run the encoder.
          encoder_outputs = self.encoder(
              input_values=input_values,
              attention_mask=attention_mask,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
      elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
          # A plain tuple was supplied but a structured output is requested: wrap it.
          provided = len(encoder_outputs)
          encoder_outputs = BaseModelOutput(
              last_hidden_state=encoder_outputs[0],
              hidden_states=encoder_outputs[1] if provided > 1 else None,
              attentions=encoder_outputs[2] if provided > 2 else None,
          )

      encoder_hidden_states = encoder_outputs[0]

      # Only speech encoders change the sequence length, so only they need the mask downsampled.
      encoder_attention_mask = attention_mask
      if attention_mask is not None and isinstance(self.encoder, SpeechT5EncoderWithSpeechPrenet):
          encoder_attention_mask = self.encoder.prenet._get_feature_vector_attention_mask(
              encoder_hidden_states.shape[1], attention_mask
          )

      decoder_kwargs = {
          "input_values": decoder_input_values,
          "attention_mask": decoder_attention_mask,
          "encoder_hidden_states": encoder_hidden_states,
          "encoder_attention_mask": encoder_attention_mask,
          "head_mask": decoder_head_mask,
          "cross_attn_head_mask": cross_attn_head_mask,
          "past_key_values": past_key_values,
          "use_cache": use_cache,
          "output_attentions": output_attentions,
          "output_hidden_states": output_hidden_states,
          "return_dict": return_dict,
          "cache_position": cache_position,
      }
      # Speaker embeddings are only passed on to speech-prenet decoders.
      if isinstance(self.decoder, SpeechT5DecoderWithSpeechPrenet):
          decoder_kwargs["speaker_embeddings"] = speaker_embeddings

      decoder_outputs = self.decoder(**decoder_kwargs)

      if not return_dict:
          return decoder_outputs + encoder_outputs

      return Seq2SeqModelOutput(
          last_hidden_state=decoder_outputs.last_hidden_state,
          past_key_values=decoder_outputs.past_key_values,
          decoder_hidden_states=decoder_outputs.hidden_states,
          decoder_attentions=decoder_outputs.attentions,
          cross_attentions=decoder_outputs.cross_attentions,
          encoder_last_hidden_state=encoder_outputs.last_hidden_state,
          encoder_hidden_states=encoder_outputs.hidden_states,
          encoder_attentions=encoder_outputs.attentions,
      )
  @auto_docstring(
      custom_intro="""
      SpeechT5 Model with a speech encoder and a text decoder.
      """
  )
  class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel, GenerationMixin):
      _tied_weights_keys = ["text_decoder_postnet.lm_head.weight"]

      def __init__(self, config: SpeechT5Config):
          super().__init__(config)

          # The text decoder head cannot be built without a vocabulary size.
          if config.vocab_size is None:
              message = (
                  f"You are trying to instantiate {self.__class__} with a configuration that does not define the"
                  " vocabulary size of the language model head. Please instantiate the model as follows:"
                  " `SpeechT5ForSpeechToText.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of"
                  " your model's configuration."
              )
              raise ValueError(message)

          # The encoder is constructed before the decoder, exactly as the argument order implies.
          self.speecht5 = SpeechT5Model(
              config,
              SpeechT5EncoderWithSpeechPrenet(config),
              SpeechT5DecoderWithTextPrenet(config),
          )
          self.text_decoder_postnet = SpeechT5TextDecoderPostnet(config)

          # Initialize weights and apply final processing
          self.post_init()

      def get_encoder(self):
          """Return the encoder of the wrapped `SpeechT5Model`."""
          backbone = self.speecht5
          return backbone.get_encoder()

      def get_decoder(self):
          """Return the decoder of the wrapped `SpeechT5Model`."""
          backbone = self.speecht5
          return backbone.get_decoder()

      def freeze_feature_encoder(self):
          """
          Disable gradient computation for the feature encoder so that its parameters are not updated during
          training.
          """
          speech_prenet = self.get_encoder().prenet
          speech_prenet.freeze_feature_encoder()

      def get_output_embeddings(self):
          """Return the output embeddings exposed by the text decoder postnet."""
          postnet = self.text_decoder_postnet
          return postnet.get_output_embeddings()

      def set_output_embeddings(self, new_embeddings):
          """Hand `new_embeddings` to the text decoder postnet as its output embeddings."""
          postnet = self.text_decoder_postnet
          postnet.set_output_embeddings(new_embeddings)

      @auto_docstring
      def forward(
          self,
          input_values: Optional[torch.FloatTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          decoder_head_mask: Optional[torch.FloatTensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
          past_key_values: Optional[Cache] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          labels: Optional[torch.LongTensor] = None,
          cache_position: Optional[torch.Tensor] = None,
      ) -> Union[tuple, Seq2SeqLMOutput]:
          r"""
          input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
              Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
              into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
              (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
              To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding
              and conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Indices of decoder input sequence tokens in the vocabulary.

              Indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are decoder input IDs?](../glossary#decoder-input-ids)

              SpeechT5 uses the `eos_token_id` as the starting token for `decoder_input_ids` generation. If
              `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
              `past_key_values`).
          decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
              also be used by default.

              If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
              and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
              information on the default strategy.
          cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
              or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
              only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

              Label indices can be obtained using [`SpeechT5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

          Example:

          ```python
          >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToText
          >>> from datasets import load_dataset

          >>> dataset = load_dataset(
          ...     "hf-internal-testing/librispeech_asr_demo", "clean", split="validation"
          ... )  # doctest: +IGNORE_RESULT
          >>> dataset = dataset.sort("id")
          >>> sampling_rate = dataset.features["audio"].sampling_rate

          >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_asr")
          >>> model = SpeechT5ForSpeechToText.from_pretrained("microsoft/speecht5_asr")

          >>> # audio file is decoded on the fly
          >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
          >>> predicted_ids = model.generate(**inputs, max_length=100)

          >>> # transcribe speech
          >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
          >>> transcription[0]
          'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel'
          ```

          ```python
          >>> inputs["labels"] = processor(text_target=dataset[0]["text"], return_tensors="pt").input_ids

          >>> # compute loss
          >>> loss = model(**inputs).loss
          >>> round(loss.item(), 2)
          19.68
          ```
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          # When training with labels and no explicit decoder inputs, derive them by shifting the labels.
          if labels is not None and decoder_input_ids is None:
              decoder_input_ids = shift_tokens_right(
                  labels, self.config.pad_token_id, self.config.decoder_start_token_id
              )

          # The backbone is always asked for a structured output; tuple conversion happens below.
          outputs = self.speecht5(
              input_values=input_values,
              attention_mask=attention_mask,
              decoder_input_values=decoder_input_ids,
              decoder_attention_mask=decoder_attention_mask,
              head_mask=head_mask,
              decoder_head_mask=decoder_head_mask,
              cross_attn_head_mask=cross_attn_head_mask,
              encoder_outputs=encoder_outputs,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=True,
              cache_position=cache_position,
          )

          logits = self.text_decoder_postnet(outputs[0])

          loss = None
          if labels is not None:
              flat_logits = logits.view(-1, self.config.vocab_size)
              flat_labels = labels.view(-1)
              loss = CrossEntropyLoss()(flat_logits, flat_labels)

          if not return_dict:
              output = (logits,) + outputs[1:]
              return output if loss is None else (loss,) + output

          return Seq2SeqLMOutput(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              decoder_hidden_states=outputs.decoder_hidden_states,
              decoder_attentions=outputs.decoder_attentions,
              cross_attentions=outputs.cross_attentions,
              encoder_last_hidden_state=outputs.encoder_last_hidden_state,
              encoder_hidden_states=outputs.encoder_hidden_states,
              encoder_attentions=outputs.encoder_attentions,
          )
  def _merge_step_cross_attentions(step_attentions, bsz, add_batch_dim):
      """
      Join the per-step cross-attention tensors collected during generation.

      Each element of `step_attentions` is the concatenation, along dim 0, of the per-layer cross-attention tensors
      of one decoding step, so dim 0 is ordered layer-major: index `layer * bsz + batch`.

      Args:
          step_attentions (`list[torch.Tensor]`): One tensor per decoding step.
          bsz (`int`): Batch size used during generation.
          add_batch_dim (`bool`): Whether to split dim 0 into separate batch and layer dimensions.

      Returns:
          `torch.Tensor`: The steps concatenated along dim 2. When `add_batch_dim` is set, the result has a leading
          batch dimension followed by the layer dimension.
      """
      merged = torch.cat(step_attentions, dim=2)
      if add_batch_dim:
          num_layers = merged.size(0) // bsz
          # Dim 0 is layer-major, so it must be unfolded as (layers, batch) and then swapped. Viewing it directly
          # as (batch, layers) would mix up layers and samples whenever bsz > 1.
          merged = merged.view(num_layers, bsz, *merged.size()[-3:]).transpose(0, 1).contiguous()
      return merged


  def _generate_speech(
      model: SpeechT5PreTrainedModel,
      input_values: torch.FloatTensor,
      speaker_embeddings: Optional[torch.FloatTensor] = None,
      attention_mask: Optional[torch.LongTensor] = None,
      threshold: float = 0.5,
      minlenratio: float = 0.0,
      maxlenratio: float = 20.0,
      vocoder: Optional[nn.Module] = None,
      output_cross_attentions: bool = False,
      return_output_lengths: bool = False,
  ) -> Union[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
      """
      Autoregressively generate mel spectrograms (and optionally waveforms) from encoder inputs.

      Args:
          model (`SpeechT5PreTrainedModel`):
              Model exposing `speecht5.encoder`, `speecht5.decoder` and `speech_decoder_postnet`.
          input_values (`torch.FloatTensor`):
              Encoder inputs; the first dimension is the batch size.
          speaker_embeddings (`torch.FloatTensor`, *optional*):
              Speaker embeddings passed to the decoder prenet. Despite the default, they are required.
          attention_mask (`torch.LongTensor`, *optional*):
              Encoder attention mask. When omitted it is derived by masking positions equal to
              `model.config.pad_token_id`.
          threshold (`float`, *optional*, defaults to 0.5):
              A sample is finished once its stop probability, summed over the last dimension, is `>= threshold`.
          minlenratio (`float`, *optional*, defaults to 0.0):
              Scales the encoder length to obtain the minimum number of decoding steps.
          maxlenratio (`float`, *optional*, defaults to 20.0):
              Scales the encoder length to obtain the maximum number of decoding steps.
          vocoder (`nn.Module`, *optional*):
              Converts spectrograms to waveforms. If `None`, spectrograms are returned.
          output_cross_attentions (`bool`, *optional*, defaults to `False`):
              Whether to also return the decoder cross-attentions.
          return_output_lengths (`bool`, *optional*, defaults to `False`):
              Whether to return padded batches together with the concrete output lengths.

      Returns:
          The spectrogram or waveform, optionally followed by the output lengths and/or the cross-attentions,
          depending on `vocoder`, `return_output_lengths` and `output_cross_attentions`.

      Raises:
          ValueError: If `speaker_embeddings` is `None`.
      """
      if speaker_embeddings is None:
          raise ValueError(
              """`speaker_embeddings` must be specified. For example, you can use a speaker embeddings by following
                      the code snippet provided in this link:
                      https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors
                      """
          )

      if attention_mask is None:
          encoder_attention_mask = 1 - (input_values == model.config.pad_token_id).int()
      else:
          encoder_attention_mask = attention_mask

      bsz = input_values.size(0)

      encoder_out = model.speecht5.encoder(
          input_values=input_values,
          attention_mask=encoder_attention_mask,
          return_dict=True,
      )

      encoder_last_hidden_state = encoder_out.last_hidden_state

      # downsample encoder attention mask
      if isinstance(model.speecht5.encoder, SpeechT5EncoderWithSpeechPrenet):
          encoder_attention_mask = model.speecht5.encoder.prenet._get_feature_vector_attention_mask(
              encoder_out[0].shape[1], encoder_attention_mask
          )

      maxlen = int(encoder_last_hidden_state.size(1) * maxlenratio / model.config.reduction_factor)
      minlen = int(encoder_last_hidden_state.size(1) * minlenratio / model.config.reduction_factor)

      # Start the output sequence with a mel spectrum that is all zeros.
      output_sequence = encoder_last_hidden_state.new_zeros(bsz, 1, model.config.num_mel_bins)

      spectrogram = []
      cross_attentions = []
      past_key_values = None
      idx = 0
      # Maps batch index -> finished spectrogram, filled in as each sample meets the stop criterion.
      result_spectrogram = {}

      while True:
          idx += 1

          # Run the decoder prenet on the entire output sequence.
          decoder_hidden_states = model.speecht5.decoder.prenet(output_sequence, speaker_embeddings)
          # Run the decoder layers on the last element of the prenet output.
          decoder_out = model.speecht5.decoder.wrapped_decoder(
              hidden_states=decoder_hidden_states[:, -1:],
              attention_mask=None,
              encoder_hidden_states=encoder_last_hidden_state,
              encoder_attention_mask=encoder_attention_mask,
              past_key_values=past_key_values,
              use_cache=True,
              output_attentions=output_cross_attentions,
              return_dict=True,
          )

          if output_cross_attentions:
              cross_attentions.append(torch.cat(decoder_out.cross_attentions, dim=0))

          last_decoder_output = decoder_out.last_hidden_state.squeeze(1)
          past_key_values = decoder_out.past_key_values

          # Predict the new mel spectrum for this step in the sequence.
          spectrum = model.speech_decoder_postnet.feat_out(last_decoder_output)
          spectrum = spectrum.view(bsz, model.config.reduction_factor, model.config.num_mel_bins)
          spectrogram.append(spectrum)

          # Extend the output sequence with the new mel spectrum.
          new_spectrogram = spectrum[:, -1, :].view(bsz, 1, model.config.num_mel_bins)
          output_sequence = torch.cat((output_sequence, new_spectrogram), dim=1)
          # Predict the probability that this is the stop token.
          prob = torch.sigmoid(model.speech_decoder_postnet.prob_out(last_decoder_output))

          if idx < minlen:
              continue

          # Below the maximum length, only samples that met the probability threshold are finished. At the
          # maximum length every remaining sample is treated as finished.
          if idx < maxlen:
              meet_thresholds = torch.sum(prob, dim=-1) >= threshold
              meet_indexes = torch.where(meet_thresholds)[0].tolist()
          else:
              meet_indexes = range(len(prob))
          meet_indexes = [i for i in meet_indexes if i not in result_spectrogram]

          if meet_indexes:
              spectrograms = torch.stack(spectrogram)
              spectrograms = spectrograms.transpose(0, 1).flatten(1, 2)
              spectrograms = model.speech_decoder_postnet.postnet(spectrograms)
              for meet_index in meet_indexes:
                  result_spectrogram[meet_index] = spectrograms[meet_index]

          if len(result_spectrogram) >= bsz:
              break

      spectrograms = [result_spectrogram[i] for i in range(len(result_spectrogram))]

      if not return_output_lengths:
          spectrogram = spectrograms[0] if bsz == 1 else torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
          outputs = vocoder(spectrogram) if vocoder is not None else spectrogram
          if output_cross_attentions:
              # A batch dimension is only added for real batches in this mode.
              cross_attentions = _merge_step_cross_attentions(cross_attentions, bsz, add_batch_dim=bsz > 1)
              outputs = (outputs, cross_attentions)
      else:
          # batched return values should also include the spectrogram/waveform lengths
          spectrogram_lengths = [item.size(0) for item in spectrograms]
          spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
          if vocoder is None:
              outputs = (spectrograms, spectrogram_lengths)
          else:
              waveforms = vocoder(spectrograms)
              # Number of waveform samples produced per spectrogram frame, derived from the padded batch.
              samples_per_frame = int(waveforms.size(1) / max(spectrogram_lengths))
              waveform_lengths = [samples_per_frame * length for length in spectrogram_lengths]
              outputs = (waveforms, waveform_lengths)
          if output_cross_attentions:
              cross_attentions = _merge_step_cross_attentions(cross_attentions, bsz, add_batch_dim=True)
              outputs = (*outputs, cross_attentions)
      return outputs
  2016. @auto_docstring(
  2017. custom_intro="""
  2018. SpeechT5 Model with a text encoder and a speech decoder.
  2019. """
  2020. )
  2021. class SpeechT5ForTextToSpeech(SpeechT5PreTrainedModel):
  2022. main_input_name = "input_ids"
  def __init__(self, config: SpeechT5Config):
      """
      Build the text encoder, the speech decoder and the speech decoder postnet from `config`.

      Raises:
          ValueError: If `config.vocab_size` is `None`.
      """
      super().__init__(config)

      if config.vocab_size is None:
          message = (
              f"You are trying to instantiate {self.__class__} with a configuration that does not define the"
              " vocabulary size of the language model head. Please instantiate the model as follows:"
              " `SpeechT5ForTextToSpeech.from_pretrained(..., vocab_size=vocab_size)`. or define `vocab_size` of"
              " your model's configuration."
          )
          raise ValueError(message)

      # The encoder is constructed before the decoder, exactly as the argument order implies.
      self.speecht5 = SpeechT5Model(
          config,
          SpeechT5EncoderWithTextPrenet(config),
          SpeechT5DecoderWithSpeechPrenet(config),
      )
      self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config)

      # Initialize weights and apply final processing
      self.post_init()
  @classmethod
  def can_generate(cls) -> bool:
      """
      Report that this model supports generation.

      `SpeechT5ForTextToSpeech` has a non-standard generation method and does not need to inherit from
      `GenerationMixin`, so the base `can_generate()` would return `False`. It is overridden here so that
      `GenerationConfig` handling still happens in the parts of the codebase that check it.
      """
      supports_generation = True
      return supports_generation
  def get_encoder(self):
      """Return the encoder of the wrapped `speecht5` model."""
      backbone = self.speecht5
      return backbone.get_encoder()
  def get_decoder(self):
      """Return the decoder of the wrapped `speecht5` model."""
      backbone = self.speecht5
      return backbone.get_decoder()
  @auto_docstring
  def forward(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.LongTensor] = None,
      decoder_input_values: Optional[torch.FloatTensor] = None,
      decoder_attention_mask: Optional[torch.LongTensor] = None,
      head_mask: Optional[torch.FloatTensor] = None,
      decoder_head_mask: Optional[torch.FloatTensor] = None,
      cross_attn_head_mask: Optional[torch.Tensor] = None,
      encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
      past_key_values: Optional[Cache] = None,
      use_cache: Optional[bool] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      speaker_embeddings: Optional[torch.FloatTensor] = None,
      labels: Optional[torch.FloatTensor] = None,
      stop_labels: Optional[torch.Tensor] = None,
      cache_position: Optional[torch.Tensor] = None,
  ) -> Union[tuple, Seq2SeqSpectrogramOutput]:
      r"""
      input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
          Indices of input sequence tokens in the vocabulary.

          Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
          [`~PreTrainedTokenizer.__call__`] for details.

          [What are input IDs?](../glossary#input-ids)
      decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`):
          Float values of input mel spectrogram.

          SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If
          `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see
          `past_key_values`).
      decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
          Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
          also be used by default.

          If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
          and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
          information on the default strategy.
      cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
          Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:

          - 1 indicates the head is **not masked**,
          - 0 indicates the head is **masked**.
      speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
          Tensor containing the speaker embeddings.
      labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
          Float values of target mel spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
          computation. Spectrograms can be obtained using [`SpeechT5Processor`]. See [`SpeechT5Processor.__call__`]
          for details.
      stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
          Binary tensor indicating the position of the stop token in the sequence.

      Example:

      ```python
      >>> from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan, set_seed
      >>> import torch

      >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
      >>> model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")
      >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

      >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
      >>> speaker_embeddings = torch.zeros((1, 512))  # or load xvectors from a file

      >>> set_seed(555)  # make deterministic

      >>> # generate speech
      >>> speech = model.generate(inputs["input_ids"], speaker_embeddings=speaker_embeddings, vocoder=vocoder)
      >>> speech.shape
      torch.Size([15872])
      ```
      """
      if return_dict is None:
          return_dict = self.config.use_return_dict

      training_with_labels = labels is not None
      if training_with_labels:
          # Derive the decoder inputs (and the matching mask) from the labels when none were given.
          if decoder_input_values is None:
              decoder_input_values, decoder_attention_mask = shift_spectrograms_right(
                  labels, self.config.reduction_factor, decoder_attention_mask
              )
          # With guided attention loss enabled, attentions are always requested so the loss receives
          # the cross-attentions.
          if self.config.use_guided_attention_loss:
              output_attentions = True

      # The backbone is always asked for a structured output; tuple conversion happens below.
      outputs = self.speecht5(
          input_values=input_ids,
          attention_mask=attention_mask,
          decoder_input_values=decoder_input_values,
          decoder_attention_mask=decoder_attention_mask,
          head_mask=head_mask,
          decoder_head_mask=decoder_head_mask,
          cross_attn_head_mask=cross_attn_head_mask,
          encoder_outputs=encoder_outputs,
          past_key_values=past_key_values,
          use_cache=use_cache,
          speaker_embeddings=speaker_embeddings,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=True,
          cache_position=cache_position,
      )

      before_postnet, after_postnet, stop_logits = self.speech_decoder_postnet(outputs[0])

      loss = None
      if training_with_labels:
          spectrogram_loss = SpeechT5SpectrogramLoss(self.config)
          loss = spectrogram_loss(
              attention_mask,
              before_postnet,
              after_postnet,
              stop_logits,
              labels,
              outputs.cross_attentions,
          )

      if not return_dict:
          output = (after_postnet,) + outputs[1:]
          return output if loss is None else (loss,) + output

      return Seq2SeqSpectrogramOutput(
          loss=loss,
          spectrogram=after_postnet,
          past_key_values=outputs.past_key_values,
          decoder_hidden_states=outputs.decoder_hidden_states,
          decoder_attentions=outputs.decoder_attentions,
          cross_attentions=outputs.cross_attentions,
          encoder_last_hidden_state=outputs.encoder_last_hidden_state,
          encoder_hidden_states=outputs.encoder_hidden_states,
          encoder_attentions=outputs.encoder_attentions,
      )
  @torch.no_grad()
  def generate(
      self,
      input_ids: torch.LongTensor,
      attention_mask: Optional[torch.LongTensor] = None,
      speaker_embeddings: Optional[torch.FloatTensor] = None,
      threshold: float = 0.5,
      minlenratio: float = 0.0,
      maxlenratio: float = 20.0,
      vocoder: Optional[nn.Module] = None,
      output_cross_attentions: bool = False,
      return_output_lengths: bool = False,
      **kwargs,
  ) -> Union[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
      r"""
      Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
      speech waveform using a vocoder.

      A single speaker embedding (first dimension of size 1) is repeated to match the batch size of `input_ids`.
      Additional keyword arguments are accepted but not used.

      Args:
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              Indices of input sequence tokens in the vocabulary.

              Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
              [`~PreTrainedTokenizer.__call__`] for details.

              [What are input IDs?](../glossary#input-ids)
          attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              Attention mask from the tokenizer, required for batched inference to signal to the model where to
              ignore padded tokens from the input_ids.
          speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
              Tensor containing the speaker embeddings.
          threshold (`float`, *optional*, defaults to 0.5):
              The generated sequence ends when the predicted stop token probability reaches or exceeds this value.
          minlenratio (`float`, *optional*, defaults to 0.0):
              Used to calculate the minimum required length for the output sequence.
          maxlenratio (`float`, *optional*, defaults to 20.0):
              Used to calculate the maximum allowed length for the output sequence.
          vocoder (`nn.Module`, *optional*):
              The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
              spectrogram.
          output_cross_attentions (`bool`, *optional*, defaults to `False`):
              Whether or not to return the attentions tensors of the decoder's cross-attention layers.
          return_output_lengths (`bool`, *optional*, defaults to `False`):
              Whether or not to return the concrete spectrogram/waveform lengths.

      Returns:
          `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
          - when `return_output_lengths` is False
              - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
              `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
              - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
              `(num_frames,)` -- The predicted speech waveform.
              - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
              `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
              output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
          - when `return_output_lengths` is True
              - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
              `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
              are padded to the maximum length.
              - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of
              all the concrete lengths for each spectrogram.
              - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
              `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
              - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all
              the concrete lengths for each waveform.
              - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
              `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
              output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
      """
      if speaker_embeddings is not None:
          batch_size = input_ids.size(0)
          num_speakers = speaker_embeddings.size(0)
          if num_speakers != batch_size:
              # Only a single embedding can be broadcast across the batch.
              if num_speakers != 1:
                  raise ValueError(
                      "The first dimension of speaker_embeddings must be either 1 or the same as batch_size."
                  )
              speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)

      return _generate_speech(
          self,
          input_values=input_ids,
          speaker_embeddings=speaker_embeddings,
          attention_mask=attention_mask,
          threshold=threshold,
          minlenratio=minlenratio,
          maxlenratio=maxlenratio,
          vocoder=vocoder,
          output_cross_attentions=output_cross_attentions,
          return_output_lengths=return_output_lengths,
      )
  2251. @torch.no_grad()
  2252. def generate_speech(
  2253. self,
  2254. input_ids: torch.LongTensor,
  2255. speaker_embeddings: Optional[torch.FloatTensor] = None,
  2256. attention_mask: Optional[torch.LongTensor] = None,
  2257. threshold: float = 0.5,
  2258. minlenratio: float = 0.0,
  2259. maxlenratio: float = 20.0,
  2260. vocoder: Optional[nn.Module] = None,
  2261. output_cross_attentions: bool = False,
  2262. return_output_lengths: bool = False,
  2263. ) -> Union[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor]]:
  2264. r"""
  2265. Converts a sequence of input tokens into a sequence of mel spectrograms, which are subsequently turned into a
  2266. speech waveform using a vocoder.
  2267. Args:
  2268. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  2269. Indices of input sequence tokens in the vocabulary.
  2270. Indices can be obtained using [`SpeechT5Tokenizer`]. See [`~PreTrainedTokenizer.encode`] and
  2271. [`~PreTrainedTokenizer.__call__`] for details.
  2272. [What are input IDs?](../glossary#input-ids)
  2273. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
  2274. Tensor containing the speaker embeddings.
  2275. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  2276. Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
  2277. `[0, 1]`:
  2278. - 1 for tokens that are **not masked**,
  2279. - 0 for tokens that are **masked**.
  2280. [What are attention masks?](../glossary#attention-mask)
  2281. threshold (`float`, *optional*, defaults to 0.5):
  2282. The generated sequence ends when the predicted stop token probability exceeds this value.
  2283. minlenratio (`float`, *optional*, defaults to 0.0):
  2284. Used to calculate the minimum required length for the output sequence.
  2285. maxlenratio (`float`, *optional*, defaults to 20.0):
  2286. Used to calculate the maximum allowed length for the output sequence.
  2287. vocoder (`nn.Module`, *optional*, defaults to `None`):
  2288. The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
  2289. spectrogram.
  2290. output_cross_attentions (`bool`, *optional*, defaults to `False`):
  2291. Whether or not to return the attentions tensors of the decoder's cross-attention layers.
  2292. return_output_lengths (`bool`, *optional*, defaults to `False`):
  2293. Whether or not to return the concrete spectrogram/waveform lengths.
  2294. Returns:
  2295. `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
  2296. - when `return_output_lengths` is False
  2297. - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2298. `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
  2299. - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2300. `(num_frames,)` -- The predicted speech waveform.
  2301. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2302. `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
  2303. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2304. - when `return_output_lengths` is True
  2305. - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2306. `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
  2307. are padded to the maximum length.
  2308. - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of
  2309. all the concrete lengths for each spectrogram.
  2310. - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2311. `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
  2312. - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all
  2313. the concrete lengths for each waveform.
  2314. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2315. `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
  2316. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2317. """
  2318. if speaker_embeddings is not None:
  2319. batch_size = input_ids.size(0)
  2320. if speaker_embeddings.size(0) != batch_size:
  2321. if speaker_embeddings.size(0) == 1:
  2322. speaker_embeddings = speaker_embeddings.repeat(batch_size, 1)
  2323. else:
  2324. raise ValueError(
  2325. "The first dimension of speaker_embeddings must be either 1 or the same as batch size."
  2326. )
  2327. return _generate_speech(
  2328. self,
  2329. input_ids,
  2330. speaker_embeddings,
  2331. attention_mask,
  2332. threshold,
  2333. minlenratio,
  2334. maxlenratio,
  2335. vocoder,
  2336. output_cross_attentions,
  2337. return_output_lengths,
  2338. )
  2339. @auto_docstring(
  2340. custom_intro="""
  2341. SpeechT5 Model with a speech encoder and a speech decoder.
  2342. """
  2343. )
  2344. class SpeechT5ForSpeechToSpeech(SpeechT5PreTrainedModel):
  2345. def __init__(self, config: SpeechT5Config):
  2346. super().__init__(config)
  2347. speech_encoder = SpeechT5EncoderWithSpeechPrenet(config)
  2348. speech_decoder = SpeechT5DecoderWithSpeechPrenet(config)
  2349. self.speecht5 = SpeechT5Model(config, speech_encoder, speech_decoder)
  2350. self.speech_decoder_postnet = SpeechT5SpeechDecoderPostnet(config)
  2351. # Initialize weights and apply final processing
  2352. self.post_init()
  2353. def get_encoder(self):
  2354. return self.speecht5.get_encoder()
  2355. def get_decoder(self):
  2356. return self.speecht5.get_decoder()
  2357. def freeze_feature_encoder(self):
  2358. """
  2359. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  2360. not be updated during training.
  2361. """
  2362. self.get_encoder().prenet.freeze_feature_encoder()
  2363. @auto_docstring
  2364. def forward(
  2365. self,
  2366. input_values: Optional[torch.FloatTensor] = None,
  2367. attention_mask: Optional[torch.LongTensor] = None,
  2368. decoder_input_values: Optional[torch.FloatTensor] = None,
  2369. decoder_attention_mask: Optional[torch.LongTensor] = None,
  2370. head_mask: Optional[torch.FloatTensor] = None,
  2371. decoder_head_mask: Optional[torch.FloatTensor] = None,
  2372. cross_attn_head_mask: Optional[torch.Tensor] = None,
  2373. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  2374. past_key_values: Optional[Cache] = None,
  2375. use_cache: Optional[bool] = None,
  2376. output_attentions: Optional[bool] = None,
  2377. output_hidden_states: Optional[bool] = None,
  2378. return_dict: Optional[bool] = None,
  2379. speaker_embeddings: Optional[torch.FloatTensor] = None,
  2380. labels: Optional[torch.FloatTensor] = None,
  2381. stop_labels: Optional[torch.Tensor] = None,
  2382. cache_position: Optional[torch.Tensor] = None,
  2383. ) -> Union[tuple, Seq2SeqSpectrogramOutput]:
  2384. r"""
  2385. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  2386. Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
  2387. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  2388. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  2389. To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and conversion into
  2390. a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
  2391. decoder_input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`):
  2392. Float values of input mel spectrogram.
  2393. SpeechT5 uses an all-zero spectrum as the starting token for `decoder_input_values` generation. If
  2394. `past_key_values` is used, optionally only the last `decoder_input_values` have to be input (see
  2395. `past_key_values`).
  2396. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  2397. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_values`. Causal mask will
  2398. also be used by default.
  2399. If you want to change padding behavior, you should read [`SpeechT5Decoder._prepare_decoder_attention_mask`]
  2400. and modify to your needs. See diagram 1 in [the paper](https://huggingface.co/papers/1910.13461) for more
  2401. information on the default strategy.
  2402. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  2403. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  2404. - 1 indicates the head is **not masked**,
  2405. - 0 indicates the head is **masked**.
  2406. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
  2407. Tensor containing the speaker embeddings.
  2408. labels (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_mel_bins)`, *optional*):
  2409. Float values of target mel spectrogram. Spectrograms can be obtained using [`SpeechT5Processor`]. See
  2410. [`SpeechT5Processor.__call__`] for details.
  2411. stop_labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  2412. Binary tensor indicating the position of the stop token in the sequence.
  2413. Example:
  2414. ```python
  2415. >>> from transformers import SpeechT5Processor, SpeechT5ForSpeechToSpeech, SpeechT5HifiGan, set_seed
  2416. >>> from datasets import load_dataset
  2417. >>> import torch
  2418. >>> dataset = load_dataset(
  2419. ... "hf-internal-testing/librispeech_asr_demo", "clean", split="validation"
  2420. ... ) # doctest: +IGNORE_RESULT
  2421. >>> dataset = dataset.sort("id")
  2422. >>> sampling_rate = dataset.features["audio"].sampling_rate
  2423. >>> processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_vc")
  2424. >>> model = SpeechT5ForSpeechToSpeech.from_pretrained("microsoft/speecht5_vc")
  2425. >>> vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
  2426. >>> # audio file is decoded on the fly
  2427. >>> inputs = processor(audio=dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
  2428. >>> speaker_embeddings = torch.zeros((1, 512)) # or load xvectors from a file
  2429. >>> set_seed(555) # make deterministic
  2430. >>> # generate speech
  2431. >>> speech = model.generate_speech(inputs["input_values"], speaker_embeddings, vocoder=vocoder)
  2432. >>> speech.shape
  2433. torch.Size([77824])
  2434. ```
  2435. """
  2436. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  2437. if labels is not None:
  2438. if decoder_input_values is None:
  2439. decoder_input_values, decoder_attention_mask = shift_spectrograms_right(
  2440. labels, self.config.reduction_factor, decoder_attention_mask
  2441. )
  2442. outputs = self.speecht5(
  2443. input_values=input_values,
  2444. attention_mask=attention_mask,
  2445. decoder_input_values=decoder_input_values,
  2446. decoder_attention_mask=decoder_attention_mask,
  2447. head_mask=head_mask,
  2448. decoder_head_mask=decoder_head_mask,
  2449. cross_attn_head_mask=cross_attn_head_mask,
  2450. encoder_outputs=encoder_outputs,
  2451. past_key_values=past_key_values,
  2452. use_cache=use_cache,
  2453. speaker_embeddings=speaker_embeddings,
  2454. output_attentions=output_attentions,
  2455. output_hidden_states=output_hidden_states,
  2456. return_dict=True,
  2457. cache_position=cache_position,
  2458. )
  2459. _, spectrogram, logits = self.speech_decoder_postnet(outputs[0])
  2460. loss = None
  2461. if not return_dict:
  2462. output = (spectrogram,) + outputs[1:]
  2463. return ((loss,) + output) if loss is not None else output
  2464. return Seq2SeqSpectrogramOutput(
  2465. loss=loss,
  2466. spectrogram=spectrogram,
  2467. past_key_values=outputs.past_key_values,
  2468. decoder_hidden_states=outputs.decoder_hidden_states,
  2469. decoder_attentions=outputs.decoder_attentions,
  2470. cross_attentions=outputs.cross_attentions,
  2471. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  2472. encoder_hidden_states=outputs.encoder_hidden_states,
  2473. encoder_attentions=outputs.encoder_attentions,
  2474. )
  2475. @torch.no_grad()
  2476. def generate_speech(
  2477. self,
  2478. input_values: torch.FloatTensor,
  2479. speaker_embeddings: Optional[torch.FloatTensor] = None,
  2480. attention_mask: Optional[torch.LongTensor] = None,
  2481. threshold: float = 0.5,
  2482. minlenratio: float = 0.0,
  2483. maxlenratio: float = 20.0,
  2484. vocoder: Optional[nn.Module] = None,
  2485. output_cross_attentions: bool = False,
  2486. return_output_lengths: bool = False,
  2487. ) -> torch.FloatTensor:
  2488. r"""
  2489. Converts a raw speech waveform into a sequence of mel spectrograms, which are subsequently turned back into a
  2490. speech waveform using a vocoder.
  2491. Args:
  2492. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  2493. Float values of input raw speech waveform.
  2494. Values can be obtained by loading a *.flac* or *.wav* audio file into an array of type `list[float]`,
  2495. a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`)
  2496. or the soundfile library (`pip install soundfile`).
  2497. To prepare the array into `input_values`, the [`SpeechT5Processor`] should be used for padding and
  2498. conversion into a tensor of type `torch.FloatTensor`. See [`SpeechT5Processor.__call__`] for details.
  2499. speaker_embeddings (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_dim)`, *optional*):
  2500. Tensor containing the speaker embeddings.
  2501. attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  2502. Mask to avoid performing convolution and attention on padding token indices. Mask values selected in
  2503. `[0, 1]`:
  2504. - 1 for tokens that are **not masked**,
  2505. - 0 for tokens that are **masked**.
  2506. [What are attention masks?](../glossary#attention-mask)
  2507. threshold (`float`, *optional*, defaults to 0.5):
  2508. The generated sequence ends when the predicted stop token probability exceeds this value.
  2509. minlenratio (`float`, *optional*, defaults to 0.0):
  2510. Used to calculate the minimum required length for the output sequence.
  2511. maxlenratio (`float`, *optional*, defaults to 20.0):
  2512. Used to calculate the maximum allowed length for the output sequence.
  2513. vocoder (`nn.Module`, *optional*, defaults to `None`):
  2514. The vocoder that converts the mel spectrogram into a speech waveform. If `None`, the output is the mel
  2515. spectrogram.
  2516. output_cross_attentions (`bool`, *optional*, defaults to `False`):
  2517. Whether or not to return the attentions tensors of the decoder's cross-attention layers.
  2518. return_output_lengths (`bool`, *optional*, defaults to `False`):
  2519. Whether or not to return the concrete spectrogram/waveform lengths.
  2520. Returns:
  2521. `tuple(torch.FloatTensor)` comprising various elements depending on the inputs:
  2522. - when `return_output_lengths` is False
  2523. - **spectrogram** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2524. `(output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrogram.
  2525. - **waveform** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2526. `(num_frames,)` -- The predicted speech waveform.
  2527. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2528. `torch.FloatTensor` of shape `(config.decoder_layers, config.decoder_attention_heads,
  2529. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2530. - when `return_output_lengths` is True
  2531. - **spectrograms** (*optional*, returned when no `vocoder` is provided) `torch.FloatTensor` of shape
  2532. `(batch_size, output_sequence_length, config.num_mel_bins)` -- The predicted log-mel spectrograms that
  2533. are padded to the maximum length.
  2534. - **spectrogram_lengths** (*optional*, returned when no `vocoder` is provided) `list[Int]` -- A list of
  2535. all the concrete lengths for each spectrogram.
  2536. - **waveforms** (*optional*, returned when a `vocoder` is provided) `torch.FloatTensor` of shape
  2537. `(batch_size, num_frames)` -- The predicted speech waveforms that are padded to the maximum length.
  2538. - **waveform_lengths** (*optional*, returned when a `vocoder` is provided) `list[Int]` -- A list of all
  2539. the concrete lengths for each waveform.
  2540. - **cross_attentions** (*optional*, returned when `output_cross_attentions` is `True`)
  2541. `torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.decoder_attention_heads,
  2542. output_sequence_length, input_sequence_length)` -- The outputs of the decoder's cross-attention layers.
  2543. """
  2544. if speaker_embeddings is None:
  2545. speaker_embeddings = torch.zeros((1, 512), device=input_values.device)
  2546. return _generate_speech(
  2547. self,
  2548. input_values,
  2549. speaker_embeddings,
  2550. attention_mask,
  2551. threshold,
  2552. minlenratio,
  2553. maxlenratio,
  2554. vocoder,
  2555. output_cross_attentions,
  2556. return_output_lengths,
  2557. )
  2558. class HifiGanResidualBlock(nn.Module):
  2559. def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
  2560. super().__init__()
  2561. self.leaky_relu_slope = leaky_relu_slope
  2562. self.convs1 = nn.ModuleList(
  2563. [
  2564. nn.Conv1d(
  2565. channels,
  2566. channels,
  2567. kernel_size,
  2568. stride=1,
  2569. dilation=dilation[i],
  2570. padding=self.get_padding(kernel_size, dilation[i]),
  2571. )
  2572. for i in range(len(dilation))
  2573. ]
  2574. )
  2575. self.convs2 = nn.ModuleList(
  2576. [
  2577. nn.Conv1d(
  2578. channels,
  2579. channels,
  2580. kernel_size,
  2581. stride=1,
  2582. dilation=1,
  2583. padding=self.get_padding(kernel_size, 1),
  2584. )
  2585. for _ in range(len(dilation))
  2586. ]
  2587. )
  2588. def get_padding(self, kernel_size, dilation=1):
  2589. return (kernel_size * dilation - dilation) // 2
  2590. def apply_weight_norm(self):
  2591. weight_norm = nn.utils.weight_norm
  2592. if hasattr(nn.utils.parametrizations, "weight_norm"):
  2593. weight_norm = nn.utils.parametrizations.weight_norm
  2594. for layer in self.convs1:
  2595. weight_norm(layer)
  2596. for layer in self.convs2:
  2597. weight_norm(layer)
  2598. def remove_weight_norm(self):
  2599. for layer in self.convs1:
  2600. nn.utils.remove_weight_norm(layer)
  2601. for layer in self.convs2:
  2602. nn.utils.remove_weight_norm(layer)
  2603. def forward(self, hidden_states):
  2604. for conv1, conv2 in zip(self.convs1, self.convs2):
  2605. residual = hidden_states
  2606. hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
  2607. hidden_states = conv1(hidden_states)
  2608. hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
  2609. hidden_states = conv2(hidden_states)
  2610. hidden_states = hidden_states + residual
  2611. return hidden_states
  2612. @auto_docstring(
  2613. custom_intro="""
  2614. HiFi-GAN vocoder.
  2615. """
  2616. )
  2617. class SpeechT5HifiGan(PreTrainedModel):
  2618. config: SpeechT5HifiGanConfig
  2619. main_input_name = "spectrogram"
  2620. def __init__(self, config: SpeechT5HifiGanConfig):
  2621. super().__init__(config)
  2622. self.num_kernels = len(config.resblock_kernel_sizes)
  2623. self.num_upsamples = len(config.upsample_rates)
  2624. self.conv_pre = nn.Conv1d(
  2625. config.model_in_dim,
  2626. config.upsample_initial_channel,
  2627. kernel_size=7,
  2628. stride=1,
  2629. padding=3,
  2630. )
  2631. self.upsampler = nn.ModuleList()
  2632. for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
  2633. self.upsampler.append(
  2634. nn.ConvTranspose1d(
  2635. config.upsample_initial_channel // (2**i),
  2636. config.upsample_initial_channel // (2 ** (i + 1)),
  2637. kernel_size=kernel_size,
  2638. stride=upsample_rate,
  2639. padding=(kernel_size - upsample_rate) // 2,
  2640. )
  2641. )
  2642. self.resblocks = nn.ModuleList()
  2643. for i in range(len(self.upsampler)):
  2644. channels = config.upsample_initial_channel // (2 ** (i + 1))
  2645. for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
  2646. self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
  2647. self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3)
  2648. self.register_buffer("mean", torch.zeros(config.model_in_dim))
  2649. self.register_buffer("scale", torch.ones(config.model_in_dim))
  2650. # Initialize weights and apply final processing
  2651. self.post_init()
  2652. def _init_weights(self, module: nn.Module):
  2653. """Initialize the weights."""
  2654. if isinstance(module, (nn.Conv1d, nn.ConvTranspose1d)):
  2655. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  2656. if module.bias is not None:
  2657. module.bias.data.zero_()
  2658. def apply_weight_norm(self):
  2659. weight_norm = nn.utils.weight_norm
  2660. if hasattr(nn.utils.parametrizations, "weight_norm"):
  2661. weight_norm = nn.utils.parametrizations.weight_norm
  2662. weight_norm(self.conv_pre)
  2663. for layer in self.upsampler:
  2664. weight_norm(layer)
  2665. for layer in self.resblocks:
  2666. layer.apply_weight_norm()
  2667. weight_norm(self.conv_post)
  2668. def remove_weight_norm(self):
  2669. nn.utils.remove_weight_norm(self.conv_pre)
  2670. for layer in self.upsampler:
  2671. nn.utils.remove_weight_norm(layer)
  2672. for layer in self.resblocks:
  2673. layer.remove_weight_norm()
  2674. nn.utils.remove_weight_norm(self.conv_post)
  2675. @auto_docstring(
  2676. custom_intro="""
  2677. Converts a log-mel spectrogram into a speech waveform. Passing a batch of log-mel spectrograms returns a batch
  2678. of speech waveforms. Passing a single, un-batched log-mel spectrogram returns a single, un-batched speech
  2679. waveform.
  2680. """
  2681. )
  2682. def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor:
  2683. r"""
  2684. spectrogram (`torch.FloatTensor`):
  2685. Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
  2686. config.model_in_dim)`, or un-batched and of shape `(sequence_length, config.model_in_dim)`.
  2687. Returns:
  2688. `torch.FloatTensor`: Tensor containing the speech waveform. If the input spectrogram is batched, will be of
  2689. shape `(batch_size, num_frames,)`. If un-batched, will be of shape `(num_frames,)`.
  2690. """
  2691. if self.config.normalize_before:
  2692. spectrogram = (spectrogram - self.mean) / self.scale
  2693. is_batched = spectrogram.dim() == 3
  2694. if not is_batched:
  2695. spectrogram = spectrogram.unsqueeze(0)
  2696. hidden_states = spectrogram.transpose(2, 1)
  2697. hidden_states = self.conv_pre(hidden_states)
  2698. for i in range(self.num_upsamples):
  2699. hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
  2700. hidden_states = self.upsampler[i](hidden_states)
  2701. res_state = self.resblocks[i * self.num_kernels](hidden_states)
  2702. for j in range(1, self.num_kernels):
  2703. res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
  2704. hidden_states = res_state / self.num_kernels
  2705. hidden_states = nn.functional.leaky_relu(hidden_states)
  2706. hidden_states = self.conv_post(hidden_states)
  2707. hidden_states = torch.tanh(hidden_states)
  2708. if not is_batched:
  2709. # remove batch dim and collapse tensor to 1-d audio waveform
  2710. waveform = hidden_states.squeeze(0).transpose(1, 0).view(-1)
  2711. else:
  2712. # remove seq-len dim since this collapses to 1
  2713. waveform = hidden_states.squeeze(1)
  2714. return waveform
# Names exported by this module when it is star-imported (`from ... import *`).
__all__ = [
    "SpeechT5ForSpeechToText",
    "SpeechT5ForSpeechToSpeech",
    "SpeechT5ForTextToSpeech",
    "SpeechT5Model",
    "SpeechT5PreTrainedModel",
    "SpeechT5HifiGan",
]