modeling_unispeech.py 63 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/unispeech/modular_unispeech.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_unispeech.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. import warnings
  23. from dataclasses import dataclass
  24. from typing import Callable, Optional, Union
  25. import numpy as np
  26. import torch
  27. import torch.nn as nn
  28. from torch.nn import CrossEntropyLoss
  29. from ...activations import ACT2FN
  30. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  31. from ...integrations.fsdp import is_fsdp_managed_module
  32. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
  33. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  34. from ...modeling_layers import GradientCheckpointingLayer
  35. from ...modeling_outputs import (
  36. BaseModelOutput,
  37. CausalLMOutput,
  38. ModelOutput,
  39. SequenceClassifierOutput,
  40. Wav2Vec2BaseModelOutput,
  41. )
  42. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  43. from ...processing_utils import Unpack
  44. from ...utils import auto_docstring, is_torch_flex_attn_available, logging
  45. from .configuration_unispeech import UniSpeechConfig
  46. if is_torch_flex_attn_available():
  47. from ...integrations.flex_attention import make_flex_block_causal_mask
  48. logger = logging.get_logger(__name__)
  49. @dataclass
  50. @auto_docstring(
  51. custom_intro="""
  52. Output type of [`UniSpeechForPreTrainingOutput`], with potential hidden states and attentions.
  53. """
  54. )
  55. class UniSpeechForPreTrainingOutput(ModelOutput):
  56. r"""
  57. loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
  58. Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
  59. paper](https://huggingface.co/papers/2006.11477).
  60. projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  61. Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
  62. projected quantized states.
  63. projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  64. Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
  65. target vectors for contrastive loss.
  66. codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
  67. The perplexity of the codevector distribution, used to measure the diversity of the codebook.
  68. """
  69. loss: Optional[torch.FloatTensor] = None
  70. projected_states: Optional[torch.FloatTensor] = None
  71. projected_quantized_states: Optional[torch.FloatTensor] = None
  72. codevector_perplexity: Optional[torch.FloatTensor] = None
  73. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  74. attentions: Optional[tuple[torch.FloatTensor]] = None
  75. class UniSpeechSamePadLayer(nn.Module):
  76. def __init__(self, num_conv_pos_embeddings):
  77. super().__init__()
  78. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  79. def forward(self, hidden_states):
  80. if self.num_pad_remove > 0:
  81. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  82. return hidden_states
  83. class UniSpeechPositionalConvEmbedding(nn.Module):
  84. def __init__(self, config):
  85. super().__init__()
  86. self.conv = nn.Conv1d(
  87. config.hidden_size,
  88. config.hidden_size,
  89. kernel_size=config.num_conv_pos_embeddings,
  90. padding=config.num_conv_pos_embeddings // 2,
  91. groups=config.num_conv_pos_embedding_groups,
  92. )
  93. weight_norm = nn.utils.weight_norm
  94. if hasattr(nn.utils.parametrizations, "weight_norm"):
  95. weight_norm = nn.utils.parametrizations.weight_norm
  96. if is_deepspeed_zero3_enabled():
  97. import deepspeed
  98. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  99. self.conv = weight_norm(self.conv, name="weight", dim=2)
  100. if hasattr(self.conv, "parametrizations"):
  101. weight_g = self.conv.parametrizations.weight.original0
  102. weight_v = self.conv.parametrizations.weight.original1
  103. else:
  104. weight_g = self.conv.weight_g
  105. weight_v = self.conv.weight_v
  106. deepspeed.zero.register_external_parameter(self, weight_v)
  107. deepspeed.zero.register_external_parameter(self, weight_g)
  108. else:
  109. self.conv = weight_norm(self.conv, name="weight", dim=2)
  110. self.padding = UniSpeechSamePadLayer(config.num_conv_pos_embeddings)
  111. self.activation = ACT2FN[config.feat_extract_activation]
  112. def forward(self, hidden_states):
  113. hidden_states = hidden_states.transpose(1, 2)
  114. hidden_states = self.conv(hidden_states)
  115. hidden_states = self.padding(hidden_states)
  116. hidden_states = self.activation(hidden_states)
  117. hidden_states = hidden_states.transpose(1, 2)
  118. return hidden_states
  119. class UniSpeechNoLayerNormConvLayer(GradientCheckpointingLayer):
  120. def __init__(self, config, layer_id=0):
  121. super().__init__()
  122. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  123. self.out_conv_dim = config.conv_dim[layer_id]
  124. self.conv = nn.Conv1d(
  125. self.in_conv_dim,
  126. self.out_conv_dim,
  127. kernel_size=config.conv_kernel[layer_id],
  128. stride=config.conv_stride[layer_id],
  129. bias=config.conv_bias,
  130. )
  131. self.activation = ACT2FN[config.feat_extract_activation]
  132. def forward(self, hidden_states):
  133. hidden_states = self.conv(hidden_states)
  134. hidden_states = self.activation(hidden_states)
  135. return hidden_states
  136. class UniSpeechLayerNormConvLayer(GradientCheckpointingLayer):
  137. def __init__(self, config, layer_id=0):
  138. super().__init__()
  139. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  140. self.out_conv_dim = config.conv_dim[layer_id]
  141. self.conv = nn.Conv1d(
  142. self.in_conv_dim,
  143. self.out_conv_dim,
  144. kernel_size=config.conv_kernel[layer_id],
  145. stride=config.conv_stride[layer_id],
  146. bias=config.conv_bias,
  147. )
  148. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  149. self.activation = ACT2FN[config.feat_extract_activation]
  150. def forward(self, hidden_states):
  151. hidden_states = self.conv(hidden_states)
  152. hidden_states = hidden_states.transpose(-2, -1)
  153. hidden_states = self.layer_norm(hidden_states)
  154. hidden_states = hidden_states.transpose(-2, -1)
  155. hidden_states = self.activation(hidden_states)
  156. return hidden_states
  157. class UniSpeechGroupNormConvLayer(GradientCheckpointingLayer):
  158. def __init__(self, config, layer_id=0):
  159. super().__init__()
  160. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  161. self.out_conv_dim = config.conv_dim[layer_id]
  162. self.conv = nn.Conv1d(
  163. self.in_conv_dim,
  164. self.out_conv_dim,
  165. kernel_size=config.conv_kernel[layer_id],
  166. stride=config.conv_stride[layer_id],
  167. bias=config.conv_bias,
  168. )
  169. self.activation = ACT2FN[config.feat_extract_activation]
  170. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  171. def forward(self, hidden_states):
  172. hidden_states = self.conv(hidden_states)
  173. hidden_states = self.layer_norm(hidden_states)
  174. hidden_states = self.activation(hidden_states)
  175. return hidden_states
  176. class UniSpeechFeatureEncoder(nn.Module):
  177. """Construct the features from raw audio waveform"""
  178. def __init__(self, config):
  179. super().__init__()
  180. if config.feat_extract_norm == "group":
  181. conv_layers = [UniSpeechGroupNormConvLayer(config, layer_id=0)] + [
  182. UniSpeechNoLayerNormConvLayer(config, layer_id=i + 1)
  183. for i in range(config.num_feat_extract_layers - 1)
  184. ]
  185. elif config.feat_extract_norm == "layer":
  186. conv_layers = [
  187. UniSpeechLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
  188. ]
  189. else:
  190. raise ValueError(
  191. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  192. )
  193. self.conv_layers = nn.ModuleList(conv_layers)
  194. self.gradient_checkpointing = False
  195. self._requires_grad = True
  196. def _freeze_parameters(self):
  197. for param in self.parameters():
  198. param.requires_grad = False
  199. self._requires_grad = False
  200. def forward(self, input_values):
  201. hidden_states = input_values[:, None]
  202. # make sure hidden_states require grad for gradient_checkpointing
  203. if self._requires_grad and self.training:
  204. hidden_states.requires_grad = True
  205. for conv_layer in self.conv_layers:
  206. hidden_states = conv_layer(hidden_states)
  207. return hidden_states
  208. class UniSpeechFeatureProjection(nn.Module):
  209. def __init__(self, config):
  210. super().__init__()
  211. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  212. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  213. self.dropout = nn.Dropout(config.feat_proj_dropout)
  214. def forward(self, hidden_states):
  215. # non-projected hidden states are needed for quantization
  216. norm_hidden_states = self.layer_norm(hidden_states)
  217. hidden_states = self.projection(norm_hidden_states)
  218. hidden_states = self.dropout(hidden_states)
  219. return hidden_states, norm_hidden_states
  220. def eager_attention_forward(
  221. module: nn.Module,
  222. query: torch.Tensor,
  223. key: torch.Tensor,
  224. value: torch.Tensor,
  225. attention_mask: Optional[torch.Tensor],
  226. scaling: Optional[float] = None,
  227. dropout: float = 0.0,
  228. head_mask: Optional[torch.Tensor] = None,
  229. **kwargs,
  230. ):
  231. if scaling is None:
  232. scaling = query.size(-1) ** -0.5
  233. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  234. if attention_mask is not None:
  235. attn_weights = attn_weights + attention_mask
  236. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  237. if head_mask is not None:
  238. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  239. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  240. attn_output = torch.matmul(attn_weights, value)
  241. attn_output = attn_output.transpose(1, 2).contiguous()
  242. return attn_output, attn_weights
  243. class UniSpeechAttention(nn.Module):
  244. """Multi-headed attention from 'Attention Is All You Need' paper"""
  245. def __init__(
  246. self,
  247. embed_dim: int,
  248. num_heads: int,
  249. dropout: float = 0.0,
  250. is_decoder: bool = False,
  251. bias: bool = True,
  252. is_causal: bool = False,
  253. config: Optional[UniSpeechConfig] = None,
  254. ):
  255. super().__init__()
  256. self.embed_dim = embed_dim
  257. self.num_heads = num_heads
  258. self.dropout = dropout
  259. self.head_dim = embed_dim // num_heads
  260. self.config = config
  261. if (self.head_dim * num_heads) != self.embed_dim:
  262. raise ValueError(
  263. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  264. f" and `num_heads`: {num_heads})."
  265. )
  266. self.scaling = self.head_dim**-0.5
  267. self.is_decoder = is_decoder
  268. self.is_causal = is_causal
  269. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  270. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  271. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  272. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  273. def forward(
  274. self,
  275. hidden_states: torch.Tensor,
  276. key_value_states: Optional[torch.Tensor] = None,
  277. attention_mask: Optional[torch.Tensor] = None,
  278. layer_head_mask: Optional[torch.Tensor] = None,
  279. output_attentions: Optional[bool] = False,
  280. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  281. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  282. **kwargs: Unpack[FlashAttentionKwargs],
  283. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  284. """Input shape: Batch x Time x Channel"""
  285. # if key_value_states are provided this layer is used as a cross-attention layer
  286. # for the decoder
  287. is_cross_attention = key_value_states is not None
  288. # determine input shapes
  289. bsz, tgt_len = hidden_states.shape[:-1]
  290. src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
  291. q_input_shape = (bsz, tgt_len, -1, self.head_dim)
  292. kv_input_shape = (bsz, src_len, -1, self.head_dim)
  293. # get query proj
  294. query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
  295. current_states = key_value_states if is_cross_attention else hidden_states
  296. key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
  297. value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
  298. attention_interface: Callable = eager_attention_forward
  299. if self.config._attn_implementation != "eager":
  300. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  301. attn_output, attn_weights = attention_interface(
  302. self,
  303. query_states,
  304. key_states,
  305. value_states,
  306. attention_mask,
  307. dropout=0.0 if not self.training else self.dropout,
  308. scaling=self.scaling,
  309. output_attentions=output_attentions,
  310. head_mask=layer_head_mask,
  311. **kwargs,
  312. )
  313. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  314. attn_output = self.out_proj(attn_output)
  315. return attn_output, attn_weights, None
  316. class UniSpeechFeedForward(nn.Module):
  317. def __init__(self, config):
  318. super().__init__()
  319. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  320. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  321. if isinstance(config.hidden_act, str):
  322. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  323. else:
  324. self.intermediate_act_fn = config.hidden_act
  325. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  326. self.output_dropout = nn.Dropout(config.hidden_dropout)
  327. def forward(self, hidden_states):
  328. hidden_states = self.intermediate_dense(hidden_states)
  329. hidden_states = self.intermediate_act_fn(hidden_states)
  330. hidden_states = self.intermediate_dropout(hidden_states)
  331. hidden_states = self.output_dense(hidden_states)
  332. hidden_states = self.output_dropout(hidden_states)
  333. return hidden_states
  334. class UniSpeechEncoderLayer(GradientCheckpointingLayer):
  335. def __init__(self, config):
  336. super().__init__()
  337. self.attention = UniSpeechAttention(
  338. embed_dim=config.hidden_size,
  339. num_heads=config.num_attention_heads,
  340. dropout=config.attention_dropout,
  341. is_decoder=False,
  342. config=config,
  343. )
  344. self.dropout = nn.Dropout(config.hidden_dropout)
  345. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  346. self.feed_forward = UniSpeechFeedForward(config)
  347. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  348. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  349. attn_residual = hidden_states
  350. hidden_states, attn_weights, _ = self.attention(
  351. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  352. )
  353. hidden_states = self.dropout(hidden_states)
  354. hidden_states = attn_residual + hidden_states
  355. hidden_states = self.layer_norm(hidden_states)
  356. hidden_states = hidden_states + self.feed_forward(hidden_states)
  357. hidden_states = self.final_layer_norm(hidden_states)
  358. outputs = (hidden_states,)
  359. if output_attentions:
  360. outputs += (attn_weights,)
  361. return outputs
  362. class UniSpeechEncoder(nn.Module):
  363. def __init__(self, config):
  364. super().__init__()
  365. self.config = config
  366. self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config)
  367. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  368. self.dropout = nn.Dropout(config.hidden_dropout)
  369. self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  370. self.gradient_checkpointing = False
  371. def forward(
  372. self,
  373. hidden_states: torch.tensor,
  374. attention_mask: Optional[torch.Tensor] = None,
  375. output_attentions: bool = False,
  376. output_hidden_states: bool = False,
  377. return_dict: bool = True,
  378. ):
  379. all_hidden_states = () if output_hidden_states else None
  380. all_self_attentions = () if output_attentions else None
  381. if attention_mask is not None:
  382. # make sure padded tokens output 0
  383. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  384. hidden_states[~expand_attention_mask] = 0
  385. attention_mask = self._update_full_mask(
  386. attention_mask,
  387. hidden_states,
  388. )
  389. position_embeddings = self.pos_conv_embed(hidden_states)
  390. hidden_states = hidden_states + position_embeddings
  391. hidden_states = self.layer_norm(hidden_states)
  392. hidden_states = self.dropout(hidden_states)
  393. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  394. for layer in self.layers:
  395. if output_hidden_states:
  396. all_hidden_states = all_hidden_states + (hidden_states,)
  397. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  398. dropout_probability = torch.rand([])
  399. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  400. if not skip_the_layer or synced_gpus:
  401. # under fsdp or deepspeed zero3 all gpus must run in sync
  402. layer_outputs = layer(
  403. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  404. )
  405. hidden_states = layer_outputs[0]
  406. if skip_the_layer:
  407. layer_outputs = (None, None)
  408. if output_attentions:
  409. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  410. if output_hidden_states:
  411. all_hidden_states = all_hidden_states + (hidden_states,)
  412. if not return_dict:
  413. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  414. return BaseModelOutput(
  415. last_hidden_state=hidden_states,
  416. hidden_states=all_hidden_states,
  417. attentions=all_self_attentions,
  418. )
  419. def _update_full_mask(
  420. self,
  421. attention_mask: Union[torch.Tensor, None],
  422. inputs_embeds: torch.Tensor,
  423. ):
  424. if attention_mask is not None:
  425. if self.config._attn_implementation == "flash_attention_2":
  426. attention_mask = attention_mask if 0 in attention_mask else None
  427. elif self.config._attn_implementation == "sdpa":
  428. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  429. # the manual implementation that requires a 4D causal mask in all cases.
  430. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  431. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  432. elif self.config._attn_implementation == "flex_attention":
  433. if isinstance(attention_mask, torch.Tensor):
  434. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  435. else:
  436. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  437. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  438. return attention_mask
  439. class UniSpeechAttnAdapterLayer(nn.Module):
  440. def __init__(self, config):
  441. """
  442. Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
  443. up training throughput.
  444. """
  445. super().__init__()
  446. self.input_dim = config.adapter_attn_dim
  447. self.hidden_dim = config.hidden_size
  448. self.norm = nn.LayerNorm(self.hidden_dim)
  449. self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
  450. self.act_fn = nn.ReLU()
  451. self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
  452. def forward(self, hidden_states: torch.FloatTensor):
  453. hidden_states = self.norm(hidden_states)
  454. hidden_states = self.linear_1(hidden_states)
  455. hidden_states = self.act_fn(hidden_states)
  456. hidden_states = self.linear_2(hidden_states)
  457. return hidden_states
  458. class UniSpeechEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
  459. def __init__(self, config):
  460. super().__init__()
  461. self.attention = UniSpeechAttention(
  462. embed_dim=config.hidden_size,
  463. num_heads=config.num_attention_heads,
  464. dropout=config.attention_dropout,
  465. is_decoder=False,
  466. config=config,
  467. )
  468. self.dropout = nn.Dropout(config.hidden_dropout)
  469. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  470. self.feed_forward = UniSpeechFeedForward(config)
  471. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  472. if getattr(config, "adapter_attn_dim", None) is not None:
  473. self.adapter_layer = UniSpeechAttnAdapterLayer(config)
  474. else:
  475. self.adapter_layer = None
  476. def forward(
  477. self,
  478. hidden_states: torch.Tensor,
  479. attention_mask: Optional[torch.Tensor] = None,
  480. output_attentions: bool = False,
  481. ):
  482. attn_residual = hidden_states
  483. hidden_states = self.layer_norm(hidden_states)
  484. hidden_states, attn_weights, _ = self.attention(
  485. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  486. )
  487. hidden_states = self.dropout(hidden_states)
  488. hidden_states = attn_residual + hidden_states
  489. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  490. if self.adapter_layer is not None:
  491. hidden_states = hidden_states + self.adapter_layer(hidden_states)
  492. outputs = (hidden_states,)
  493. if output_attentions:
  494. outputs += (attn_weights,)
  495. return outputs
  496. class UniSpeechEncoderStableLayerNorm(nn.Module):
  497. def __init__(self, config):
  498. super().__init__()
  499. self.config = config
  500. self.pos_conv_embed = UniSpeechPositionalConvEmbedding(config)
  501. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  502. self.dropout = nn.Dropout(config.hidden_dropout)
  503. self.layers = nn.ModuleList(
  504. [UniSpeechEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
  505. )
  506. self.gradient_checkpointing = False
  507. def forward(
  508. self,
  509. hidden_states,
  510. attention_mask=None,
  511. output_attentions=False,
  512. output_hidden_states=False,
  513. return_dict=True,
  514. ):
  515. all_hidden_states = () if output_hidden_states else None
  516. all_self_attentions = () if output_attentions else None
  517. if attention_mask is not None:
  518. # make sure padded tokens output 0
  519. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  520. hidden_states[~expand_attention_mask] = 0
  521. attention_mask = self._update_full_mask(
  522. attention_mask,
  523. hidden_states,
  524. )
  525. position_embeddings = self.pos_conv_embed(hidden_states)
  526. hidden_states = hidden_states + position_embeddings
  527. hidden_states = self.dropout(hidden_states)
  528. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  529. for layer in self.layers:
  530. if output_hidden_states:
  531. all_hidden_states = all_hidden_states + (hidden_states,)
  532. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  533. dropout_probability = torch.rand([])
  534. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  535. if not skip_the_layer or synced_gpus:
  536. # under fsdp or deepspeed zero3 all gpus must run in sync
  537. # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
  538. layer_outputs = layer(
  539. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  540. )
  541. hidden_states = layer_outputs[0]
  542. if skip_the_layer:
  543. layer_outputs = (None, None)
  544. if output_attentions:
  545. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  546. hidden_states = self.layer_norm(hidden_states)
  547. if output_hidden_states:
  548. all_hidden_states = all_hidden_states + (hidden_states,)
  549. if not return_dict:
  550. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  551. return BaseModelOutput(
  552. last_hidden_state=hidden_states,
  553. hidden_states=all_hidden_states,
  554. attentions=all_self_attentions,
  555. )
  556. def _update_full_mask(
  557. self,
  558. attention_mask: Union[torch.Tensor, None],
  559. inputs_embeds: torch.Tensor,
  560. ):
  561. if attention_mask is not None:
  562. if self.config._attn_implementation == "flash_attention_2":
  563. attention_mask = attention_mask if 0 in attention_mask else None
  564. elif self.config._attn_implementation == "sdpa":
  565. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  566. # the manual implementation that requires a 4D causal mask in all cases.
  567. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  568. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  569. elif self.config._attn_implementation == "flex_attention":
  570. if isinstance(attention_mask, torch.Tensor):
  571. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  572. else:
  573. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  574. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  575. return attention_mask
  576. class UniSpeechGumbelVectorQuantizer(nn.Module):
  577. """
  578. Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
  579. GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
  580. """
  581. def __init__(self, config):
  582. super().__init__()
  583. self.num_groups = config.num_codevector_groups
  584. self.num_vars = config.num_codevectors_per_group
  585. if config.codevector_dim % self.num_groups != 0:
  586. raise ValueError(
  587. f"`config.codevector_dim {config.codevector_dim} must be divisible "
  588. f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
  589. )
  590. # storage for codebook variables (codewords)
  591. self.codevectors = nn.Parameter(
  592. torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
  593. )
  594. self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
  595. # can be decayed for training
  596. self.temperature = 2
  597. @staticmethod
  598. def _compute_perplexity(probs):
  599. marginal_probs = probs.mean(dim=0)
  600. perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
  601. return perplexity
  602. def forward(self, hidden_states):
  603. batch_size, sequence_length, hidden_size = hidden_states.shape
  604. # project to codevector dim
  605. hidden_states = self.weight_proj(hidden_states)
  606. hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
  607. if self.training:
  608. # sample code vector probs via gumbel in differentiateable way
  609. codevector_probs = nn.functional.gumbel_softmax(
  610. hidden_states.float(), tau=self.temperature, hard=True
  611. ).type_as(hidden_states)
  612. # compute perplexity
  613. codevector_soft_dist = torch.softmax(
  614. hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
  615. )
  616. perplexity = self._compute_perplexity(codevector_soft_dist)
  617. else:
  618. # take argmax in non-differentiable way
  619. # comptute hard codevector distribution (one hot)
  620. codevector_idx = hidden_states.argmax(dim=-1)
  621. codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
  622. -1, codevector_idx.view(-1, 1), 1.0
  623. )
  624. codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
  625. perplexity = self._compute_perplexity(codevector_probs)
  626. codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
  627. # use probs to retrieve codevectors
  628. codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
  629. codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
  630. codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
  631. return codevectors, perplexity
  632. @auto_docstring
  633. class UniSpeechPreTrainedModel(PreTrainedModel):
  634. config: UniSpeechConfig
  635. base_model_prefix = "unispeech"
  636. main_input_name = "input_values"
  637. supports_gradient_checkpointing = True
  638. _supports_flash_attn = True
  639. _supports_sdpa = True
  640. _supports_flex_attn = True
  641. def _init_weights(self, module):
  642. """Initialize the weights"""
  643. # gumbel softmax requires special init
  644. if isinstance(module, UniSpeechGumbelVectorQuantizer):
  645. module.weight_proj.weight.data.normal_(mean=0.0, std=1)
  646. module.weight_proj.bias.data.zero_()
  647. nn.init.uniform_(module.codevectors)
  648. elif isinstance(module, UniSpeechPositionalConvEmbedding):
  649. nn.init.normal_(
  650. module.conv.weight,
  651. mean=0,
  652. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  653. )
  654. nn.init.constant_(module.conv.bias, 0)
  655. elif isinstance(module, UniSpeechFeatureProjection):
  656. k = math.sqrt(1 / module.projection.in_features)
  657. nn.init.uniform_(module.projection.weight, a=-k, b=k)
  658. nn.init.uniform_(module.projection.bias, a=-k, b=k)
  659. elif isinstance(module, nn.Linear):
  660. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  661. if module.bias is not None:
  662. module.bias.data.zero_()
  663. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  664. module.bias.data.zero_()
  665. module.weight.data.fill_(1.0)
  666. elif isinstance(module, nn.Conv1d):
  667. nn.init.kaiming_normal_(module.weight)
  668. if module.bias is not None:
  669. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  670. nn.init.uniform_(module.bias, a=-k, b=k)
  671. def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
  672. """
  673. Computes the output length of the convolutional layers
  674. """
  675. def _conv_out_length(input_length, kernel_size, stride):
  676. # 1D convolutional layer output length formula taken
  677. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  678. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  679. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  680. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  681. return input_lengths
  682. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  683. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  684. # on inference mode.
  685. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  686. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
  687. batch_size = attention_mask.shape[0]
  688. attention_mask = torch.zeros(
  689. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  690. )
  691. # these two operations makes sure that all values before the output lengths idxs are attended to
  692. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  693. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  694. return attention_mask
  695. def _compute_mask_indices(
  696. shape: tuple[int, int],
  697. mask_prob: float,
  698. mask_length: int,
  699. attention_mask: Optional[torch.LongTensor] = None,
  700. min_masks: int = 0,
  701. ) -> np.ndarray:
  702. """
  703. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  704. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  705. CPU as part of the preprocessing during training.
  706. Args:
  707. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  708. the first element is the batch size and the second element is the length of the axis to span.
  709. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  710. independently generated mask spans of length `mask_length` is computed by
  711. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  712. actual percentage will be smaller.
  713. mask_length: size of the mask
  714. min_masks: minimum number of masked spans
  715. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  716. each batch dimension.
  717. """
  718. batch_size, sequence_length = shape
  719. if mask_length < 1:
  720. raise ValueError("`mask_length` has to be bigger than 0.")
  721. if mask_length > sequence_length:
  722. raise ValueError(
  723. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  724. f" and `sequence_length`: {sequence_length}`"
  725. )
  726. # epsilon is used for probabilistic rounding
  727. epsilon = np.random.rand(1).item()
  728. def compute_num_masked_span(input_length):
  729. """Given input length, compute how many spans should be masked"""
  730. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  731. num_masked_span = max(num_masked_span, min_masks)
  732. # make sure num masked span <= sequence_length
  733. if num_masked_span * mask_length > sequence_length:
  734. num_masked_span = sequence_length // mask_length
  735. # make sure num_masked span is also <= input_length - (mask_length - 1)
  736. if input_length - (mask_length - 1) < num_masked_span:
  737. num_masked_span = max(input_length - (mask_length - 1), 0)
  738. return num_masked_span
  739. # compute number of masked spans in batch
  740. input_lengths = (
  741. attention_mask.detach().sum(-1).tolist()
  742. if attention_mask is not None
  743. else [sequence_length for _ in range(batch_size)]
  744. )
  745. # SpecAugment mask to fill
  746. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  747. spec_aug_mask_idxs = []
  748. max_num_masked_span = compute_num_masked_span(sequence_length)
  749. if max_num_masked_span == 0:
  750. return spec_aug_mask
  751. for input_length in input_lengths:
  752. # compute num of masked spans for this input
  753. num_masked_span = compute_num_masked_span(input_length)
  754. # get random indices to mask
  755. spec_aug_mask_idx = np.random.choice(
  756. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  757. )
  758. # pick first sampled index that will serve as a dummy index to pad vector
  759. # to ensure same dimension for all batches due to probabilistic rounding
  760. # Picking first sample just pads those vectors twice.
  761. if len(spec_aug_mask_idx) == 0:
  762. # this case can only happen if `input_length` is strictly smaller then
  763. # `sequence_length` in which case the last token has to be a padding
  764. # token which we can use as a dummy mask id
  765. dummy_mask_idx = sequence_length - 1
  766. else:
  767. dummy_mask_idx = spec_aug_mask_idx[0]
  768. spec_aug_mask_idx = np.concatenate(
  769. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  770. )
  771. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  772. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  773. # expand masked indices to masked spans
  774. spec_aug_mask_idxs = np.broadcast_to(
  775. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  776. )
  777. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  778. # add offset to the starting indexes so that indexes now create a span
  779. offsets = np.arange(mask_length)[None, None, :]
  780. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  781. batch_size, max_num_masked_span * mask_length
  782. )
  783. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  784. # ensure that we cannot have indices larger than sequence_length
  785. if spec_aug_mask_idxs.max() > sequence_length - 1:
  786. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  787. # scatter indices to mask
  788. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  789. return spec_aug_mask
# Alias giving the base-model output type a UniSpeech-specific name. `UniSpeechModel.forward` builds it
# with the fields `last_hidden_state`, `extract_features`, `hidden_states` and `attentions`.
UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
  791. @auto_docstring
  792. class UniSpeechModel(UniSpeechPreTrainedModel):
  793. def __init__(self, config: UniSpeechConfig):
  794. super().__init__(config)
  795. self.config = config
  796. self.feature_extractor = UniSpeechFeatureEncoder(config)
  797. self.feature_projection = UniSpeechFeatureProjection(config)
  798. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  799. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  800. if config.do_stable_layer_norm:
  801. self.encoder = UniSpeechEncoderStableLayerNorm(config)
  802. else:
  803. self.encoder = UniSpeechEncoder(config)
  804. # Initialize weights and apply final processing
  805. self.post_init()
  806. def _mask_hidden_states(
  807. self,
  808. hidden_states: torch.FloatTensor,
  809. mask_time_indices: Optional[torch.FloatTensor] = None,
  810. attention_mask: Optional[torch.LongTensor] = None,
  811. ):
  812. """
  813. Masks extracted features along time axis and/or along feature axis according to
  814. [SpecAugment](https://huggingface.co/papers/1904.08779).
  815. """
  816. # `config.apply_spec_augment` can set masking to False
  817. if not getattr(self.config, "apply_spec_augment", True):
  818. return hidden_states
  819. # generate indices & apply SpecAugment along time axis
  820. batch_size, sequence_length, hidden_size = hidden_states.size()
  821. if mask_time_indices is not None:
  822. # apply SpecAugment along time axis with given mask_time_indices
  823. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  824. elif self.config.mask_time_prob > 0 and self.training:
  825. mask_time_indices = _compute_mask_indices(
  826. (batch_size, sequence_length),
  827. mask_prob=self.config.mask_time_prob,
  828. mask_length=self.config.mask_time_length,
  829. attention_mask=attention_mask,
  830. min_masks=self.config.mask_time_min_masks,
  831. )
  832. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  833. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  834. if self.config.mask_feature_prob > 0 and self.training:
  835. # generate indices & apply SpecAugment along feature axis
  836. mask_feature_indices = _compute_mask_indices(
  837. (batch_size, hidden_size),
  838. mask_prob=self.config.mask_feature_prob,
  839. mask_length=self.config.mask_feature_length,
  840. min_masks=self.config.mask_feature_min_masks,
  841. )
  842. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  843. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  844. hidden_states[mask_feature_indices] = 0
  845. return hidden_states
  846. @auto_docstring
  847. def forward(
  848. self,
  849. input_values: Optional[torch.Tensor],
  850. attention_mask: Optional[torch.Tensor] = None,
  851. mask_time_indices: Optional[torch.FloatTensor] = None,
  852. output_attentions: Optional[bool] = None,
  853. output_hidden_states: Optional[bool] = None,
  854. return_dict: Optional[bool] = None,
  855. ) -> Union[tuple, UniSpeechBaseModelOutput]:
  856. r"""
  857. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  858. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  859. masked extracted features in *config.proj_codevector_dim* space.
  860. """
  861. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  862. output_hidden_states = (
  863. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  864. )
  865. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  866. extract_features = self.feature_extractor(input_values)
  867. extract_features = extract_features.transpose(1, 2)
  868. if attention_mask is not None:
  869. # compute reduced attention_mask corresponding to feature vectors
  870. attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
  871. hidden_states, extract_features = self.feature_projection(extract_features)
  872. hidden_states = self._mask_hidden_states(
  873. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  874. )
  875. encoder_outputs = self.encoder(
  876. hidden_states,
  877. attention_mask=attention_mask,
  878. output_attentions=output_attentions,
  879. output_hidden_states=output_hidden_states,
  880. return_dict=return_dict,
  881. )
  882. hidden_states = encoder_outputs[0]
  883. if not return_dict:
  884. return (hidden_states, extract_features) + encoder_outputs[1:]
  885. return UniSpeechBaseModelOutput(
  886. last_hidden_state=hidden_states,
  887. extract_features=extract_features,
  888. hidden_states=encoder_outputs.hidden_states,
  889. attentions=encoder_outputs.attentions,
  890. )
  891. @auto_docstring(
  892. custom_intro="""
  893. UniSpeech Model with a vector-quantization module and ctc loss for pre-training.
  894. """
  895. )
  896. class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
  897. def __init__(self, config: UniSpeechConfig):
  898. super().__init__(config)
  899. self.unispeech = UniSpeechModel(config)
  900. self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
  901. self.quantizer = UniSpeechGumbelVectorQuantizer(config)
  902. self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
  903. self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)
  904. self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
  905. self.dropout = nn.Dropout(config.final_dropout)
  906. # Initialize weights and apply final processing
  907. self.post_init()
  908. def set_gumbel_temperature(self, temperature: int):
  909. """
  910. Set the Gumbel softmax temperature to a given value. Only necessary for training
  911. """
  912. self.quantizer.temperature = temperature
  913. def freeze_feature_extractor(self):
  914. """
  915. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  916. not be updated during training.
  917. """
  918. warnings.warn(
  919. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  920. "Please use the equivalent `freeze_feature_encoder` method instead.",
  921. FutureWarning,
  922. )
  923. self.freeze_feature_encoder()
  924. def freeze_feature_encoder(self):
  925. """
  926. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  927. not be updated during training.
  928. """
  929. self.unispeech.feature_extractor._freeze_parameters()
  930. @staticmethod
  931. def compute_contrastive_logits(
  932. target_features: torch.FloatTensor,
  933. negative_features: torch.FloatTensor,
  934. predicted_features: torch.FloatTensor,
  935. temperature: int = 1,
  936. ):
  937. """
  938. Compute logits for contrastive loss based using cosine similarity as the distance measure between
  939. `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
  940. """
  941. target_features = torch.cat([target_features, negative_features], dim=0)
  942. logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1)
  943. logits = logits.type_as(target_features)
  944. # apply temperature
  945. logits = logits / temperature
  946. return logits
  947. @auto_docstring
  948. def forward(
  949. self,
  950. input_values: Optional[torch.Tensor],
  951. attention_mask: Optional[torch.Tensor] = None,
  952. output_attentions: Optional[bool] = None,
  953. output_hidden_states: Optional[bool] = None,
  954. return_dict: Optional[bool] = None,
  955. ) -> Union[tuple, UniSpeechForPreTrainingOutput]:
  956. r"""
  957. Example:
  958. ```python
  959. >>> import torch
  960. >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining
  961. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
  962. >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
  963. >>> # TODO: Add full pretraining example
  964. ```"""
  965. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  966. outputs = self.unispeech(
  967. input_values,
  968. attention_mask=attention_mask,
  969. output_attentions=output_attentions,
  970. output_hidden_states=output_hidden_states,
  971. return_dict=return_dict,
  972. )
  973. transformer_features = outputs[0]
  974. # quantize all (unmasked) extracted features and project to final vq dim
  975. extract_features = self.dropout_features(outputs[1])
  976. quantized_features, codevector_perplexity = self.quantizer(extract_features)
  977. # project quantized features twice
  978. quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
  979. quantized_features = self.project_hid(quantized_features)
  980. prob_replace_matrix = torch.empty(transformer_features.size(0), transformer_features.size(1)).fill_(
  981. self.config.replace_prob
  982. )
  983. prob_replace_matrix = prob_replace_matrix.transpose(0, 1)
  984. sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device)
  985. sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1)
  986. sampled_replace_matrix = sampled_replace_matrix.unsqueeze(-1)
  987. logits = transformer_features.masked_fill(sampled_replace_matrix, 0.0) + (
  988. quantized_features.masked_fill(~sampled_replace_matrix, 0.0)
  989. )
  990. # project to ctc units
  991. logits = self.dropout(logits)
  992. logits = self.ctc_proj(logits)
  993. # TODO(PVP) - add negative sampling & loss computation
  994. loss = None
  995. if not return_dict:
  996. if loss is not None:
  997. return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  998. return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  999. return UniSpeechForPreTrainingOutput(
  1000. loss=loss,
  1001. projected_states=transformer_features,
  1002. projected_quantized_states=quantized_features,
  1003. codevector_perplexity=codevector_perplexity,
  1004. hidden_states=outputs.hidden_states,
  1005. attentions=outputs.attentions,
  1006. )
# Index at which the optional hidden states / attentions start in the tuple returned by `UniSpeechModel`
# (positions 0 and 1 hold the last hidden state and the extracted features).
_HIDDEN_STATES_START_POSITION = 2
  1008. @auto_docstring(
  1009. custom_intro="""
  1010. UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
  1011. """
  1012. )
  1013. class UniSpeechForCTC(UniSpeechPreTrainedModel):
  1014. def __init__(self, config, target_lang: Optional[str] = None):
  1015. r"""
  1016. target_lang (`str`, *optional*):
  1017. Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
  1018. adapter.<lang>.bin. Only relevant when using an instance of [`UniSpeechForCTC`] with adapters. Uses 'eng' by
  1019. default.
  1020. """
  1021. super().__init__(config)
  1022. self.unispeech = UniSpeechModel(config)
  1023. self.dropout = nn.Dropout(config.final_dropout)
  1024. self.target_lang = target_lang
  1025. if config.vocab_size is None:
  1026. raise ValueError(
  1027. f"You are trying to instantiate {self.__class__} with a configuration that "
  1028. "does not define the vocabulary size of the language model head. Please "
  1029. "instantiate the model as follows: `UniSpeechForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  1030. "or define `vocab_size` of your model's configuration."
  1031. )
  1032. output_hidden_size = (
  1033. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  1034. )
  1035. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  1036. # Initialize weights and apply final processing
  1037. self.post_init()
  1038. def tie_weights(self):
  1039. """
  1040. This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
  1041. passing `target_lang=...` to `from_pretrained(...)`.
  1042. This method is **not** supposed to be called by the user and is prone to be changed in the future.
  1043. """
  1044. # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
  1045. # correctly load adapter layers for UniSpeech so that we do not have to introduce a new API to
  1046. # [`PreTrainedModel`]. While slightly hacky, UniSpeech never has to tie input and output embeddings, so that it is
  1047. # ok to repurpose this function here.
  1048. target_lang = self.target_lang
  1049. if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
  1050. raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
  1051. elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
  1052. logger.info("By default `target_lang` is set to 'eng'.")
  1053. elif target_lang is not None:
  1054. self.load_adapter(target_lang, force_load=True)
  1055. def freeze_feature_extractor(self):
  1056. """
  1057. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1058. not be updated during training.
  1059. """
  1060. warnings.warn(
  1061. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1062. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1063. FutureWarning,
  1064. )
  1065. self.freeze_feature_encoder()
  1066. def freeze_feature_encoder(self):
  1067. """
  1068. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1069. not be updated during training.
  1070. """
  1071. self.unispeech.feature_extractor._freeze_parameters()
  1072. def freeze_base_model(self):
  1073. """
  1074. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1075. be updated during training. Only the classification head will be updated.
  1076. """
  1077. for param in self.unispeech.parameters():
  1078. param.requires_grad = False
  1079. @auto_docstring
  1080. def forward(
  1081. self,
  1082. input_values: Optional[torch.Tensor],
  1083. attention_mask: Optional[torch.Tensor] = None,
  1084. output_attentions: Optional[bool] = None,
  1085. output_hidden_states: Optional[bool] = None,
  1086. return_dict: Optional[bool] = None,
  1087. labels: Optional[torch.Tensor] = None,
  1088. ) -> Union[tuple, CausalLMOutput]:
  1089. r"""
  1090. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  1091. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  1092. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  1093. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  1094. config.vocab_size - 1]`.
  1095. """
  1096. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1097. if labels is not None and labels.max() >= self.config.vocab_size:
  1098. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  1099. outputs = self.unispeech(
  1100. input_values,
  1101. attention_mask=attention_mask,
  1102. output_attentions=output_attentions,
  1103. output_hidden_states=output_hidden_states,
  1104. return_dict=return_dict,
  1105. )
  1106. hidden_states = outputs[0]
  1107. hidden_states = self.dropout(hidden_states)
  1108. logits = self.lm_head(hidden_states)
  1109. loss = None
  1110. if labels is not None:
  1111. # retrieve loss input_lengths from attention_mask
  1112. attention_mask = (
  1113. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  1114. )
  1115. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  1116. # assuming that padded tokens are filled with -100
  1117. # when not being attended to
  1118. labels_mask = labels >= 0
  1119. target_lengths = labels_mask.sum(-1)
  1120. flattened_targets = labels.masked_select(labels_mask)
  1121. # ctc_loss doesn't support fp16
  1122. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  1123. with torch.backends.cudnn.flags(enabled=False):
  1124. loss = nn.functional.ctc_loss(
  1125. log_probs,
  1126. flattened_targets,
  1127. input_lengths,
  1128. target_lengths,
  1129. blank=self.config.pad_token_id,
  1130. reduction=self.config.ctc_loss_reduction,
  1131. zero_infinity=self.config.ctc_zero_infinity,
  1132. )
  1133. if not return_dict:
  1134. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1135. return ((loss,) + output) if loss is not None else output
  1136. return CausalLMOutput(
  1137. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  1138. )
  1139. @auto_docstring(
  1140. custom_intro="""
  1141. UniSpeech Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
  1142. SUPERB Keyword Spotting.
  1143. """
  1144. )
  1145. class UniSpeechForSequenceClassification(UniSpeechPreTrainedModel):
  1146. def __init__(self, config):
  1147. super().__init__(config)
  1148. if hasattr(config, "add_adapter") and config.add_adapter:
  1149. raise ValueError(
  1150. "Sequence classification does not support the use of UniSpeech adapters (config.add_adapter=True)"
  1151. )
  1152. self.unispeech = UniSpeechModel(config)
  1153. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1154. if config.use_weighted_layer_sum:
  1155. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1156. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  1157. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  1158. # Initialize weights and apply final processing
  1159. self.post_init()
  1160. def freeze_feature_extractor(self):
  1161. """
  1162. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  1163. not be updated during training.
  1164. """
  1165. warnings.warn(
  1166. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1167. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1168. FutureWarning,
  1169. )
  1170. self.freeze_feature_encoder()
  1171. def freeze_feature_encoder(self):
  1172. """
  1173. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1174. not be updated during training.
  1175. """
  1176. self.unispeech.feature_extractor._freeze_parameters()
  1177. def freeze_base_model(self):
  1178. """
  1179. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1180. be updated during training. Only the classification head will be updated.
  1181. """
  1182. for param in self.unispeech.parameters():
  1183. param.requires_grad = False
  @auto_docstring
  def forward(
      self,
      input_values: Optional[torch.Tensor],
      attention_mask: Optional[torch.Tensor] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
      labels: Optional[torch.Tensor] = None,
  ) -> Union[tuple, SequenceClassifierOutput]:
      r"""
      input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
          Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
          into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
          (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
          To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
          into a tensor of type `torch.FloatTensor`. See [`UniSpeechProcessor.__call__`] for details.
      labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
          Labels for computing the sequence classification loss. Indices should be in `[0, ...,
          config.num_labels - 1]`. A classification loss (Cross-Entropy) is always computed; this head does not
          implement a regression loss, including when `config.num_labels == 1`.
      """
      return_dict = return_dict if return_dict is not None else self.config.use_return_dict
      # The weighted layer sum consumes every layer's output, so hidden states are forced on in that mode.
      output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

      outputs = self.unispeech(
          input_values,
          attention_mask=attention_mask,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )

      if self.config.use_weighted_layer_sum:
          # Stack the per-layer hidden states along a new dim 1 and mix them with softmax-normalized
          # learned weights (`self.layer_weights`).
          hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
          hidden_states = torch.stack(hidden_states, dim=1)
          norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
          hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
      else:
          hidden_states = outputs[0]

      hidden_states = self.projector(hidden_states)

      if attention_mask is None:
          # No mask supplied: plain mean over dim 1.
          pooled_output = hidden_states.mean(dim=1)
      else:
          # Masked mean over dim 1: zero the positions the mask marks as padding, then divide the sum by
          # the number of kept positions per sample. The mask is broadcast over the last dim instead of
          # being materialized at full size, and the projected tensor is not modified in place.
          padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
          hidden_states = hidden_states.masked_fill(~padding_mask.unsqueeze(-1), 0.0)
          pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

      logits = self.classifier(pooled_output)

      loss = None
      if labels is not None:
          loss_fct = CrossEntropyLoss()
          loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

      if not return_dict:
          # Tuple output: logits followed by the base model outputs from the hidden-states position on,
          # with the loss prepended when labels were given.
          output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
          return ((loss,) + output) if loss is not None else output

      return SequenceClassifierOutput(
          loss=loss,
          logits=logits,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
      )
  # Names exported by `from <this module> import *`.
  __all__ = [
      "UniSpeechForCTC",
      "UniSpeechForPreTraining",
      "UniSpeechForSequenceClassification",
      "UniSpeechModel",
      "UniSpeechPreTrainedModel",
  ]