modeling_wav2vec2.py 99 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Wav2Vec2 model."""
  16. import math
  17. import warnings
  18. from dataclasses import dataclass
  19. from typing import Callable, Optional, Union
  20. import numpy as np
  21. import torch
  22. from safetensors.torch import load_file as safe_load_file
  23. from torch import nn
  24. from torch.nn import CrossEntropyLoss
  25. from ...activations import ACT2FN
  26. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  27. from ...integrations.fsdp import is_fsdp_managed_module
  28. from ...modeling_attn_mask_utils import (
  29. _prepare_4d_attention_mask,
  30. _prepare_4d_attention_mask_for_sdpa,
  31. )
  32. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  33. from ...modeling_layers import GradientCheckpointingLayer
  34. from ...modeling_outputs import (
  35. BaseModelOutput,
  36. CausalLMOutput,
  37. MaskedLMOutput,
  38. SequenceClassifierOutput,
  39. TokenClassifierOutput,
  40. Wav2Vec2BaseModelOutput,
  41. XVectorOutput,
  42. )
  43. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  44. from ...processing_utils import Unpack
  45. from ...utils import (
  46. ModelOutput,
  47. auto_docstring,
  48. cached_file,
  49. check_torch_load_is_safe,
  50. is_peft_available,
  51. is_torch_flex_attn_available,
  52. logging,
  53. )
  54. from .configuration_wav2vec2 import Wav2Vec2Config
  55. WAV2VEC2_ADAPTER_PT_FILE = "adapter.{}.bin"
  56. WAV2VEC2_ADAPTER_SAFE_FILE = "adapter.{}.safetensors"
  57. if is_torch_flex_attn_available():
  58. from ...integrations.flex_attention import make_flex_block_causal_mask
  59. logger = logging.get_logger(__name__)
  60. _HIDDEN_STATES_START_POSITION = 2
  61. @dataclass
  62. @auto_docstring(
  63. custom_intro="""
  64. Output type of [`Wav2Vec2ForPreTraining`], with potential hidden states and attentions.
  65. """
  66. )
  67. class Wav2Vec2ForPreTrainingOutput(ModelOutput):
  68. r"""
  69. loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
  70. Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
  71. paper](https://huggingface.co/papers/2006.11477).
  72. projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  73. Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
  74. projected quantized states.
  75. projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  76. Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
  77. target vectors for contrastive loss.
  78. codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
  79. The perplexity of the codevector distribution, used to measure the diversity of the codebook.
  80. contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
  81. The contrastive loss (L_m) as stated in the [official paper](https://huggingface.co/papers/2006.11477).
  82. diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`):
  83. The diversity loss (L_d) as stated in the [official paper](https://huggingface.co/papers/2006.11477).
  84. """
  85. loss: Optional[torch.FloatTensor] = None
  86. projected_states: Optional[torch.FloatTensor] = None
  87. projected_quantized_states: Optional[torch.FloatTensor] = None
  88. codevector_perplexity: Optional[torch.FloatTensor] = None
  89. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  90. attentions: Optional[tuple[torch.FloatTensor]] = None
  91. contrastive_loss: Optional[torch.FloatTensor] = None
  92. diversity_loss: Optional[torch.FloatTensor] = None
  93. def _compute_mask_indices(
  94. shape: tuple[int, int],
  95. mask_prob: float,
  96. mask_length: int,
  97. attention_mask: Optional[torch.LongTensor] = None,
  98. min_masks: int = 0,
  99. ) -> np.ndarray:
  100. """
  101. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  102. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  103. CPU as part of the preprocessing during training.
  104. Args:
  105. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  106. the first element is the batch size and the second element is the length of the axis to span.
  107. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  108. independently generated mask spans of length `mask_length` is computed by
  109. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  110. actual percentage will be smaller.
  111. mask_length: size of the mask
  112. min_masks: minimum number of masked spans
  113. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  114. each batch dimension.
  115. """
  116. batch_size, sequence_length = shape
  117. if mask_length < 1:
  118. raise ValueError("`mask_length` has to be bigger than 0.")
  119. if mask_length > sequence_length:
  120. raise ValueError(
  121. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  122. f" and `sequence_length`: {sequence_length}`"
  123. )
  124. # epsilon is used for probabilistic rounding
  125. epsilon = np.random.rand(1).item()
  126. def compute_num_masked_span(input_length):
  127. """Given input length, compute how many spans should be masked"""
  128. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  129. num_masked_span = max(num_masked_span, min_masks)
  130. # make sure num masked span <= sequence_length
  131. if num_masked_span * mask_length > sequence_length:
  132. num_masked_span = sequence_length // mask_length
  133. # make sure num_masked span is also <= input_length - (mask_length - 1)
  134. if input_length - (mask_length - 1) < num_masked_span:
  135. num_masked_span = max(input_length - (mask_length - 1), 0)
  136. return num_masked_span
  137. # compute number of masked spans in batch
  138. input_lengths = (
  139. attention_mask.detach().sum(-1).tolist()
  140. if attention_mask is not None
  141. else [sequence_length for _ in range(batch_size)]
  142. )
  143. # SpecAugment mask to fill
  144. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  145. spec_aug_mask_idxs = []
  146. max_num_masked_span = compute_num_masked_span(sequence_length)
  147. if max_num_masked_span == 0:
  148. return spec_aug_mask
  149. for input_length in input_lengths:
  150. # compute num of masked spans for this input
  151. num_masked_span = compute_num_masked_span(input_length)
  152. # get random indices to mask
  153. spec_aug_mask_idx = np.random.choice(
  154. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  155. )
  156. # pick first sampled index that will serve as a dummy index to pad vector
  157. # to ensure same dimension for all batches due to probabilistic rounding
  158. # Picking first sample just pads those vectors twice.
  159. if len(spec_aug_mask_idx) == 0:
  160. # this case can only happen if `input_length` is strictly smaller then
  161. # `sequence_length` in which case the last token has to be a padding
  162. # token which we can use as a dummy mask id
  163. dummy_mask_idx = sequence_length - 1
  164. else:
  165. dummy_mask_idx = spec_aug_mask_idx[0]
  166. spec_aug_mask_idx = np.concatenate(
  167. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  168. )
  169. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  170. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  171. # expand masked indices to masked spans
  172. spec_aug_mask_idxs = np.broadcast_to(
  173. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  174. )
  175. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  176. # add offset to the starting indexes so that indexes now create a span
  177. offsets = np.arange(mask_length)[None, None, :]
  178. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  179. batch_size, max_num_masked_span * mask_length
  180. )
  181. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  182. # ensure that we cannot have indices larger than sequence_length
  183. if spec_aug_mask_idxs.max() > sequence_length - 1:
  184. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  185. # scatter indices to mask
  186. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  187. return spec_aug_mask
  188. def _sample_negative_indices(
  189. features_shape: tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None
  190. ):
  191. """
  192. Sample `num_negatives` vectors from feature vectors.
  193. """
  194. batch_size, sequence_length = features_shape
  195. # generate indices of the positive vectors themselves, repeat them `num_negatives` times
  196. sequence_length_range = np.arange(sequence_length)
  197. # get `num_negatives` random vector indices from the same utterance
  198. sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32)
  199. mask_time_indices = (
  200. mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool)
  201. )
  202. for batch_idx in range(batch_size):
  203. high = mask_time_indices[batch_idx].sum() - 1
  204. mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]]
  205. feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives))
  206. sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives))
  207. # avoid sampling the same positive vector, but keep the distribution uniform
  208. sampled_indices[sampled_indices >= feature_indices] += 1
  209. # remap to actual indices
  210. sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices]
  211. # correct for batch size
  212. sampled_negative_indices[batch_idx] += batch_idx * sequence_length
  213. return sampled_negative_indices
  214. class Wav2Vec2NoLayerNormConvLayer(GradientCheckpointingLayer):
  215. def __init__(self, config, layer_id=0):
  216. super().__init__()
  217. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  218. self.out_conv_dim = config.conv_dim[layer_id]
  219. self.conv = nn.Conv1d(
  220. self.in_conv_dim,
  221. self.out_conv_dim,
  222. kernel_size=config.conv_kernel[layer_id],
  223. stride=config.conv_stride[layer_id],
  224. bias=config.conv_bias,
  225. )
  226. self.activation = ACT2FN[config.feat_extract_activation]
  227. def forward(self, hidden_states):
  228. hidden_states = self.conv(hidden_states)
  229. hidden_states = self.activation(hidden_states)
  230. return hidden_states
  231. class Wav2Vec2LayerNormConvLayer(GradientCheckpointingLayer):
  232. def __init__(self, config, layer_id=0):
  233. super().__init__()
  234. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  235. self.out_conv_dim = config.conv_dim[layer_id]
  236. self.conv = nn.Conv1d(
  237. self.in_conv_dim,
  238. self.out_conv_dim,
  239. kernel_size=config.conv_kernel[layer_id],
  240. stride=config.conv_stride[layer_id],
  241. bias=config.conv_bias,
  242. )
  243. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  244. self.activation = ACT2FN[config.feat_extract_activation]
  245. def forward(self, hidden_states):
  246. hidden_states = self.conv(hidden_states)
  247. hidden_states = hidden_states.transpose(-2, -1)
  248. hidden_states = self.layer_norm(hidden_states)
  249. hidden_states = hidden_states.transpose(-2, -1)
  250. hidden_states = self.activation(hidden_states)
  251. return hidden_states
  252. class Wav2Vec2GroupNormConvLayer(GradientCheckpointingLayer):
  253. def __init__(self, config, layer_id=0):
  254. super().__init__()
  255. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  256. self.out_conv_dim = config.conv_dim[layer_id]
  257. self.conv = nn.Conv1d(
  258. self.in_conv_dim,
  259. self.out_conv_dim,
  260. kernel_size=config.conv_kernel[layer_id],
  261. stride=config.conv_stride[layer_id],
  262. bias=config.conv_bias,
  263. )
  264. self.activation = ACT2FN[config.feat_extract_activation]
  265. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  266. def forward(self, hidden_states):
  267. hidden_states = self.conv(hidden_states)
  268. hidden_states = self.layer_norm(hidden_states)
  269. hidden_states = self.activation(hidden_states)
  270. return hidden_states
  271. class Wav2Vec2PositionalConvEmbedding(nn.Module):
  272. def __init__(self, config):
  273. super().__init__()
  274. self.conv = nn.Conv1d(
  275. config.hidden_size,
  276. config.hidden_size,
  277. kernel_size=config.num_conv_pos_embeddings,
  278. padding=config.num_conv_pos_embeddings // 2,
  279. groups=config.num_conv_pos_embedding_groups,
  280. )
  281. weight_norm = nn.utils.weight_norm
  282. if hasattr(nn.utils.parametrizations, "weight_norm"):
  283. weight_norm = nn.utils.parametrizations.weight_norm
  284. if is_deepspeed_zero3_enabled():
  285. import deepspeed
  286. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  287. self.conv = weight_norm(self.conv, name="weight", dim=2)
  288. if hasattr(self.conv, "parametrizations"):
  289. weight_g = self.conv.parametrizations.weight.original0
  290. weight_v = self.conv.parametrizations.weight.original1
  291. else:
  292. weight_g = self.conv.weight_g
  293. weight_v = self.conv.weight_v
  294. deepspeed.zero.register_external_parameter(self, weight_v)
  295. deepspeed.zero.register_external_parameter(self, weight_g)
  296. else:
  297. self.conv = weight_norm(self.conv, name="weight", dim=2)
  298. self.padding = Wav2Vec2SamePadLayer(config.num_conv_pos_embeddings)
  299. self.activation = ACT2FN[config.feat_extract_activation]
  300. def forward(self, hidden_states):
  301. hidden_states = hidden_states.transpose(1, 2)
  302. hidden_states = self.conv(hidden_states)
  303. hidden_states = self.padding(hidden_states)
  304. hidden_states = self.activation(hidden_states)
  305. hidden_states = hidden_states.transpose(1, 2)
  306. return hidden_states
  307. class Wav2Vec2SamePadLayer(nn.Module):
  308. def __init__(self, num_conv_pos_embeddings):
  309. super().__init__()
  310. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  311. def forward(self, hidden_states):
  312. if self.num_pad_remove > 0:
  313. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  314. return hidden_states
  315. class Wav2Vec2FeatureEncoder(nn.Module):
  316. """Construct the features from raw audio waveform"""
  317. def __init__(self, config):
  318. super().__init__()
  319. if config.feat_extract_norm == "group":
  320. conv_layers = [Wav2Vec2GroupNormConvLayer(config, layer_id=0)] + [
  321. Wav2Vec2NoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
  322. ]
  323. elif config.feat_extract_norm == "layer":
  324. conv_layers = [
  325. Wav2Vec2LayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)
  326. ]
  327. else:
  328. raise ValueError(
  329. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  330. )
  331. self.conv_layers = nn.ModuleList(conv_layers)
  332. self.gradient_checkpointing = False
  333. self._requires_grad = True
  334. def _freeze_parameters(self):
  335. for param in self.parameters():
  336. param.requires_grad = False
  337. self._requires_grad = False
  338. def forward(self, input_values):
  339. hidden_states = input_values[:, None]
  340. # make sure hidden_states require grad for gradient_checkpointing
  341. if self._requires_grad and self.training:
  342. hidden_states.requires_grad = True
  343. for conv_layer in self.conv_layers:
  344. hidden_states = conv_layer(hidden_states)
  345. return hidden_states
  346. class Wav2Vec2FeatureExtractor(Wav2Vec2FeatureEncoder):
  347. def __init__(self, config):
  348. super().__init__(config)
  349. warnings.warn(
  350. f"The class `{self.__class__.__name__}` has been depreciated "
  351. "and will be removed in Transformers v5. "
  352. f"Use `{self.__class__.__bases__[0].__name__}` instead.",
  353. FutureWarning,
  354. )
  355. class Wav2Vec2FeatureProjection(nn.Module):
  356. def __init__(self, config):
  357. super().__init__()
  358. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  359. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  360. self.dropout = nn.Dropout(config.feat_proj_dropout)
  361. def forward(self, hidden_states):
  362. # non-projected hidden states are needed for quantization
  363. norm_hidden_states = self.layer_norm(hidden_states)
  364. hidden_states = self.projection(norm_hidden_states)
  365. hidden_states = self.dropout(hidden_states)
  366. return hidden_states, norm_hidden_states
  367. # Copied from transformers.models.bart.modeling_bart.eager_attention_forward
  368. def eager_attention_forward(
  369. module: nn.Module,
  370. query: torch.Tensor,
  371. key: torch.Tensor,
  372. value: torch.Tensor,
  373. attention_mask: Optional[torch.Tensor],
  374. scaling: Optional[float] = None,
  375. dropout: float = 0.0,
  376. head_mask: Optional[torch.Tensor] = None,
  377. **kwargs,
  378. ):
  379. if scaling is None:
  380. scaling = query.size(-1) ** -0.5
  381. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  382. if attention_mask is not None:
  383. attn_weights = attn_weights + attention_mask
  384. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  385. if head_mask is not None:
  386. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  387. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  388. attn_output = torch.matmul(attn_weights, value)
  389. attn_output = attn_output.transpose(1, 2).contiguous()
  390. return attn_output, attn_weights
  391. class Wav2Vec2Attention(nn.Module):
  392. """Multi-headed attention from 'Attention Is All You Need' paper"""
  393. def __init__(
  394. self,
  395. embed_dim: int,
  396. num_heads: int,
  397. dropout: float = 0.0,
  398. is_decoder: bool = False,
  399. bias: bool = True,
  400. is_causal: bool = False,
  401. config: Optional[Wav2Vec2Config] = None,
  402. ):
  403. super().__init__()
  404. self.embed_dim = embed_dim
  405. self.num_heads = num_heads
  406. self.dropout = dropout
  407. self.head_dim = embed_dim // num_heads
  408. self.config = config
  409. if (self.head_dim * num_heads) != self.embed_dim:
  410. raise ValueError(
  411. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  412. f" and `num_heads`: {num_heads})."
  413. )
  414. self.scaling = self.head_dim**-0.5
  415. self.is_decoder = is_decoder
  416. self.is_causal = is_causal
  417. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  418. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  419. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  420. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  421. def forward(
  422. self,
  423. hidden_states: torch.Tensor,
  424. key_value_states: Optional[torch.Tensor] = None,
  425. attention_mask: Optional[torch.Tensor] = None,
  426. layer_head_mask: Optional[torch.Tensor] = None,
  427. output_attentions: Optional[bool] = False,
  428. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  429. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  430. **kwargs: Unpack[FlashAttentionKwargs],
  431. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  432. """Input shape: Batch x Time x Channel"""
  433. # if key_value_states are provided this layer is used as a cross-attention layer
  434. # for the decoder
  435. is_cross_attention = key_value_states is not None
  436. # determine input shapes
  437. bsz, tgt_len = hidden_states.shape[:-1]
  438. src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
  439. q_input_shape = (bsz, tgt_len, -1, self.head_dim)
  440. kv_input_shape = (bsz, src_len, -1, self.head_dim)
  441. # get query proj
  442. query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
  443. current_states = key_value_states if is_cross_attention else hidden_states
  444. key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
  445. value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)
  446. attention_interface: Callable = eager_attention_forward
  447. if self.config._attn_implementation != "eager":
  448. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  449. attn_output, attn_weights = attention_interface(
  450. self,
  451. query_states,
  452. key_states,
  453. value_states,
  454. attention_mask,
  455. dropout=0.0 if not self.training else self.dropout,
  456. scaling=self.scaling,
  457. output_attentions=output_attentions,
  458. head_mask=layer_head_mask,
  459. **kwargs,
  460. )
  461. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  462. attn_output = self.out_proj(attn_output)
  463. return attn_output, attn_weights, None
  464. class Wav2Vec2FeedForward(nn.Module):
  465. def __init__(self, config):
  466. super().__init__()
  467. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  468. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  469. if isinstance(config.hidden_act, str):
  470. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  471. else:
  472. self.intermediate_act_fn = config.hidden_act
  473. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  474. self.output_dropout = nn.Dropout(config.hidden_dropout)
  475. def forward(self, hidden_states):
  476. hidden_states = self.intermediate_dense(hidden_states)
  477. hidden_states = self.intermediate_act_fn(hidden_states)
  478. hidden_states = self.intermediate_dropout(hidden_states)
  479. hidden_states = self.output_dense(hidden_states)
  480. hidden_states = self.output_dropout(hidden_states)
  481. return hidden_states
  482. class Wav2Vec2EncoderLayer(GradientCheckpointingLayer):
  483. def __init__(self, config):
  484. super().__init__()
  485. self.attention = Wav2Vec2Attention(
  486. embed_dim=config.hidden_size,
  487. num_heads=config.num_attention_heads,
  488. dropout=config.attention_dropout,
  489. is_decoder=False,
  490. config=config,
  491. )
  492. self.dropout = nn.Dropout(config.hidden_dropout)
  493. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  494. self.feed_forward = Wav2Vec2FeedForward(config)
  495. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  496. def forward(self, hidden_states, attention_mask=None, output_attentions=False):
  497. attn_residual = hidden_states
  498. hidden_states, attn_weights, _ = self.attention(
  499. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  500. )
  501. hidden_states = self.dropout(hidden_states)
  502. hidden_states = attn_residual + hidden_states
  503. hidden_states = self.layer_norm(hidden_states)
  504. hidden_states = hidden_states + self.feed_forward(hidden_states)
  505. hidden_states = self.final_layer_norm(hidden_states)
  506. outputs = (hidden_states,)
  507. if output_attentions:
  508. outputs += (attn_weights,)
  509. return outputs
  510. class Wav2Vec2EncoderLayerStableLayerNorm(GradientCheckpointingLayer):
  511. def __init__(self, config):
  512. super().__init__()
  513. self.attention = Wav2Vec2Attention(
  514. embed_dim=config.hidden_size,
  515. num_heads=config.num_attention_heads,
  516. dropout=config.attention_dropout,
  517. is_decoder=False,
  518. config=config,
  519. )
  520. self.dropout = nn.Dropout(config.hidden_dropout)
  521. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  522. self.feed_forward = Wav2Vec2FeedForward(config)
  523. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  524. if getattr(config, "adapter_attn_dim", None) is not None:
  525. self.adapter_layer = Wav2Vec2AttnAdapterLayer(config)
  526. else:
  527. self.adapter_layer = None
  528. def forward(
  529. self,
  530. hidden_states: torch.Tensor,
  531. attention_mask: Optional[torch.Tensor] = None,
  532. output_attentions: bool = False,
  533. ):
  534. attn_residual = hidden_states
  535. hidden_states = self.layer_norm(hidden_states)
  536. hidden_states, attn_weights, _ = self.attention(
  537. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  538. )
  539. hidden_states = self.dropout(hidden_states)
  540. hidden_states = attn_residual + hidden_states
  541. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  542. if self.adapter_layer is not None:
  543. hidden_states = hidden_states + self.adapter_layer(hidden_states)
  544. outputs = (hidden_states,)
  545. if output_attentions:
  546. outputs += (attn_weights,)
  547. return outputs
  548. class Wav2Vec2Encoder(nn.Module):
  549. def __init__(self, config):
  550. super().__init__()
  551. self.config = config
  552. self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
  553. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  554. self.dropout = nn.Dropout(config.hidden_dropout)
  555. self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
  556. self.gradient_checkpointing = False
  557. def forward(
  558. self,
  559. hidden_states: torch.tensor,
  560. attention_mask: Optional[torch.Tensor] = None,
  561. output_attentions: bool = False,
  562. output_hidden_states: bool = False,
  563. return_dict: bool = True,
  564. ):
  565. all_hidden_states = () if output_hidden_states else None
  566. all_self_attentions = () if output_attentions else None
  567. if attention_mask is not None:
  568. # make sure padded tokens output 0
  569. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  570. hidden_states[~expand_attention_mask] = 0
  571. attention_mask = self._update_full_mask(
  572. attention_mask,
  573. hidden_states,
  574. )
  575. position_embeddings = self.pos_conv_embed(hidden_states)
  576. hidden_states = hidden_states + position_embeddings
  577. hidden_states = self.layer_norm(hidden_states)
  578. hidden_states = self.dropout(hidden_states)
  579. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  580. for layer in self.layers:
  581. if output_hidden_states:
  582. all_hidden_states = all_hidden_states + (hidden_states,)
  583. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  584. dropout_probability = torch.rand([])
  585. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  586. if not skip_the_layer or synced_gpus:
  587. # under fsdp or deepspeed zero3 all gpus must run in sync
  588. layer_outputs = layer(
  589. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  590. )
  591. hidden_states = layer_outputs[0]
  592. if skip_the_layer:
  593. layer_outputs = (None, None)
  594. if output_attentions:
  595. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  596. if output_hidden_states:
  597. all_hidden_states = all_hidden_states + (hidden_states,)
  598. if not return_dict:
  599. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  600. return BaseModelOutput(
  601. last_hidden_state=hidden_states,
  602. hidden_states=all_hidden_states,
  603. attentions=all_self_attentions,
  604. )
  605. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
  606. def _update_full_mask(
  607. self,
  608. attention_mask: Union[torch.Tensor, None],
  609. inputs_embeds: torch.Tensor,
  610. ):
  611. if attention_mask is not None:
  612. if self.config._attn_implementation == "flash_attention_2":
  613. attention_mask = attention_mask if 0 in attention_mask else None
  614. elif self.config._attn_implementation == "sdpa":
  615. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  616. # the manual implementation that requires a 4D causal mask in all cases.
  617. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  618. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  619. elif self.config._attn_implementation == "flex_attention":
  620. if isinstance(attention_mask, torch.Tensor):
  621. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  622. else:
  623. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  624. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  625. return attention_mask
  626. class Wav2Vec2EncoderStableLayerNorm(nn.Module):
  627. def __init__(self, config):
  628. super().__init__()
  629. self.config = config
  630. self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
  631. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  632. self.dropout = nn.Dropout(config.hidden_dropout)
  633. self.layers = nn.ModuleList(
  634. [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
  635. )
  636. self.gradient_checkpointing = False
  637. def forward(
  638. self,
  639. hidden_states,
  640. attention_mask=None,
  641. output_attentions=False,
  642. output_hidden_states=False,
  643. return_dict=True,
  644. ):
  645. all_hidden_states = () if output_hidden_states else None
  646. all_self_attentions = () if output_attentions else None
  647. if attention_mask is not None:
  648. # make sure padded tokens output 0
  649. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  650. hidden_states[~expand_attention_mask] = 0
  651. attention_mask = self._update_full_mask(
  652. attention_mask,
  653. hidden_states,
  654. )
  655. position_embeddings = self.pos_conv_embed(hidden_states)
  656. hidden_states = hidden_states + position_embeddings
  657. hidden_states = self.dropout(hidden_states)
  658. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  659. for layer in self.layers:
  660. if output_hidden_states:
  661. all_hidden_states = all_hidden_states + (hidden_states,)
  662. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  663. dropout_probability = torch.rand([])
  664. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  665. if not skip_the_layer or synced_gpus:
  666. # under fsdp or deepspeed zero3 all gpus must run in sync
  667. # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
  668. layer_outputs = layer(
  669. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  670. )
  671. hidden_states = layer_outputs[0]
  672. if skip_the_layer:
  673. layer_outputs = (None, None)
  674. if output_attentions:
  675. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  676. hidden_states = self.layer_norm(hidden_states)
  677. if output_hidden_states:
  678. all_hidden_states = all_hidden_states + (hidden_states,)
  679. if not return_dict:
  680. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  681. return BaseModelOutput(
  682. last_hidden_state=hidden_states,
  683. hidden_states=all_hidden_states,
  684. attentions=all_self_attentions,
  685. )
  686. # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
  687. def _update_full_mask(
  688. self,
  689. attention_mask: Union[torch.Tensor, None],
  690. inputs_embeds: torch.Tensor,
  691. ):
  692. if attention_mask is not None:
  693. if self.config._attn_implementation == "flash_attention_2":
  694. attention_mask = attention_mask if 0 in attention_mask else None
  695. elif self.config._attn_implementation == "sdpa":
  696. # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
  697. # the manual implementation that requires a 4D causal mask in all cases.
  698. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  699. attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
  700. elif self.config._attn_implementation == "flex_attention":
  701. if isinstance(attention_mask, torch.Tensor):
  702. attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
  703. else:
  704. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  705. attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
  706. return attention_mask
  707. class Wav2Vec2GumbelVectorQuantizer(nn.Module):
  708. """
  709. Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH
  710. GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
  711. """
  712. def __init__(self, config):
  713. super().__init__()
  714. self.num_groups = config.num_codevector_groups
  715. self.num_vars = config.num_codevectors_per_group
  716. if config.codevector_dim % self.num_groups != 0:
  717. raise ValueError(
  718. f"`config.codevector_dim {config.codevector_dim} must be divisible "
  719. f"by `config.num_codevector_groups` {self.num_groups} for concatenation"
  720. )
  721. # storage for codebook variables (codewords)
  722. self.codevectors = nn.Parameter(
  723. torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
  724. )
  725. self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
  726. # can be decayed for training
  727. self.temperature = 2
  728. @staticmethod
  729. def _compute_perplexity(probs, mask=None):
  730. if mask is not None:
  731. mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
  732. probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
  733. marginal_probs = probs.sum(dim=0) / mask.sum()
  734. else:
  735. marginal_probs = probs.mean(dim=0)
  736. perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
  737. return perplexity
  738. def forward(self, hidden_states, mask_time_indices=None):
  739. batch_size, sequence_length, hidden_size = hidden_states.shape
  740. # project to codevector dim
  741. hidden_states = self.weight_proj(hidden_states)
  742. hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
  743. if self.training:
  744. # sample code vector probs via gumbel in differentiateable way
  745. codevector_probs = nn.functional.gumbel_softmax(
  746. hidden_states.float(), tau=self.temperature, hard=True
  747. ).type_as(hidden_states)
  748. # compute perplexity
  749. codevector_soft_dist = torch.softmax(
  750. hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
  751. )
  752. perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices)
  753. else:
  754. # take argmax in non-differentiable way
  755. # comptute hard codevector distribution (one hot)
  756. codevector_idx = hidden_states.argmax(dim=-1)
  757. codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_(
  758. -1, codevector_idx.view(-1, 1), 1.0
  759. )
  760. codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
  761. perplexity = self._compute_perplexity(codevector_probs, mask_time_indices)
  762. codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
  763. # use probs to retrieve codevectors
  764. codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
  765. codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
  766. codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
  767. return codevectors, perplexity
  768. class Wav2Vec2Adapter(nn.Module):
  769. def __init__(self, config):
  770. super().__init__()
  771. # feature dim might need to be down-projected
  772. if config.output_hidden_size != config.hidden_size:
  773. self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
  774. self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
  775. else:
  776. self.proj = self.proj_layer_norm = None
  777. self.layers = nn.ModuleList(Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers))
  778. self.layerdrop = config.layerdrop
  779. def forward(self, hidden_states):
  780. # down project hidden_states if necessary
  781. if self.proj is not None and self.proj_layer_norm is not None:
  782. hidden_states = self.proj(hidden_states)
  783. hidden_states = self.proj_layer_norm(hidden_states)
  784. hidden_states = hidden_states.transpose(1, 2)
  785. for layer in self.layers:
  786. layerdrop_prob = np.random.random()
  787. if not self.training or (layerdrop_prob > self.layerdrop):
  788. hidden_states = layer(hidden_states)
  789. hidden_states = hidden_states.transpose(1, 2)
  790. return hidden_states
  791. class Wav2Vec2AdapterLayer(nn.Module):
  792. def __init__(self, config):
  793. super().__init__()
  794. self.conv = nn.Conv1d(
  795. config.output_hidden_size,
  796. 2 * config.output_hidden_size,
  797. config.adapter_kernel_size,
  798. stride=config.adapter_stride,
  799. padding=1,
  800. )
  801. def forward(self, hidden_states):
  802. hidden_states = self.conv(hidden_states)
  803. hidden_states = nn.functional.glu(hidden_states, dim=1)
  804. return hidden_states
  805. class Wav2Vec2AttnAdapterLayer(nn.Module):
  806. def __init__(self, config):
  807. """
  808. Implements adapter modules directly with 3D tensor weight as parameters and without using ModuleList to speed
  809. up training throughput.
  810. """
  811. super().__init__()
  812. self.input_dim = config.adapter_attn_dim
  813. self.hidden_dim = config.hidden_size
  814. self.norm = nn.LayerNorm(self.hidden_dim)
  815. self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
  816. self.act_fn = nn.ReLU()
  817. self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)
  818. def forward(self, hidden_states: torch.FloatTensor):
  819. hidden_states = self.norm(hidden_states)
  820. hidden_states = self.linear_1(hidden_states)
  821. hidden_states = self.act_fn(hidden_states)
  822. hidden_states = self.linear_2(hidden_states)
  823. return hidden_states
  824. @auto_docstring
  825. class Wav2Vec2PreTrainedModel(PreTrainedModel):
  826. config: Wav2Vec2Config
  827. base_model_prefix = "wav2vec2"
  828. main_input_name = "input_values"
  829. supports_gradient_checkpointing = True
  830. _supports_flash_attn = True
  831. _supports_sdpa = True
  832. _supports_flex_attn = True
  833. def _init_weights(self, module):
  834. """Initialize the weights"""
  835. # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init.
  836. if isinstance(module, Wav2Vec2ForPreTraining):
  837. module.project_hid.reset_parameters()
  838. module.project_q.reset_parameters()
  839. module.project_hid._is_hf_initialized = True
  840. module.project_q._is_hf_initialized = True
  841. # gumbel softmax requires special init
  842. elif isinstance(module, Wav2Vec2GumbelVectorQuantizer):
  843. module.weight_proj.weight.data.normal_(mean=0.0, std=1)
  844. module.weight_proj.bias.data.zero_()
  845. nn.init.uniform_(module.codevectors)
  846. elif isinstance(module, Wav2Vec2PositionalConvEmbedding):
  847. nn.init.normal_(
  848. module.conv.weight,
  849. mean=0,
  850. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  851. )
  852. nn.init.constant_(module.conv.bias, 0)
  853. elif isinstance(module, Wav2Vec2FeatureProjection):
  854. k = math.sqrt(1 / module.projection.in_features)
  855. nn.init.uniform_(module.projection.weight, a=-k, b=k)
  856. nn.init.uniform_(module.projection.bias, a=-k, b=k)
  857. elif isinstance(module, nn.Linear):
  858. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  859. if module.bias is not None:
  860. module.bias.data.zero_()
  861. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  862. module.bias.data.zero_()
  863. module.weight.data.fill_(1.0)
  864. elif isinstance(module, nn.Conv1d):
  865. nn.init.kaiming_normal_(module.weight)
  866. if module.bias is not None:
  867. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  868. nn.init.uniform_(module.bias, a=-k, b=k)
  869. def _get_feat_extract_output_lengths(
  870. self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
  871. ):
  872. """
  873. Computes the output length of the convolutional layers
  874. """
  875. add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
  876. def _conv_out_length(input_length, kernel_size, stride):
  877. # 1D convolutional layer output length formula taken
  878. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  879. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  880. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  881. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  882. if add_adapter:
  883. for _ in range(self.config.num_adapter_layers):
  884. input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
  885. return input_lengths
  886. def _get_feature_vector_attention_mask(
  887. self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
  888. ):
  889. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  890. # on inference mode.
  891. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  892. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
  893. output_lengths = output_lengths.to(torch.long)
  894. batch_size = attention_mask.shape[0]
  895. attention_mask = torch.zeros(
  896. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  897. )
  898. # these two operations makes sure that all values before the output lengths idxs are attended to
  899. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  900. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  901. return attention_mask
  902. def _get_adapters(self):
  903. if self.config.adapter_attn_dim is None:
  904. raise ValueError(f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`.")
  905. adapter_weights = {}
  906. for name, module in self.named_modules():
  907. if isinstance(module, Wav2Vec2AttnAdapterLayer):
  908. for param_name, param in module.named_parameters():
  909. adapter_weights[".".join([name, param_name])] = param
  910. if isinstance(self, Wav2Vec2ForCTC):
  911. for name, param in self.lm_head.named_parameters():
  912. adapter_weights[".".join(["lm_head", name])] = param
  913. return adapter_weights
  914. def init_adapter_layers(self):
  915. """
  916. (Re-)initialize attention adapter layers and lm head for adapter-only fine-tuning
  917. """
  918. # init attention adapters
  919. for module in self.modules():
  920. if isinstance(module, Wav2Vec2AttnAdapterLayer):
  921. self._init_weights(module)
  922. # init lm head
  923. if isinstance(self, Wav2Vec2ForCTC):
  924. self._init_weights(self.lm_head)
  925. def load_adapter(self, target_lang: str, force_load=True, **kwargs):
  926. r"""
  927. Load a language adapter model from a pre-trained adapter model.
  928. Parameters:
  929. target_lang (`str`):
  930. Has to be a language id of an existing adapter weight. Adapter weights are stored in the format
  931. adapter.<lang>.safetensors or adapter.<lang>.bin
  932. force_load (`bool`, defaults to `True`):
  933. Whether the weights shall be loaded even if `target_lang` matches `self.target_lang`.
  934. cache_dir (`Union[str, os.PathLike]`, *optional*):
  935. Path to a directory in which a downloaded pretrained model configuration should be cached if the
  936. standard cache should not be used.
  937. force_download (`bool`, *optional*, defaults to `False`):
  938. Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  939. cached versions if they exist.
  940. resume_download:
  941. Deprecated and ignored. All downloads are now resumed by default when possible.
  942. Will be removed in v5 of Transformers.
  943. proxies (`dict[str, str]`, *optional*):
  944. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  945. 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
  946. local_files_only(`bool`, *optional*, defaults to `False`):
  947. Whether or not to only look at local files (i.e., do not try to download the model).
  948. token (`str` or `bool`, *optional*):
  949. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  950. the token generated when running `hf auth login` (stored in `~/.huggingface`).
  951. revision (`str`, *optional*, defaults to `"main"`):
  952. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  953. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  954. identifier allowed by git.
  955. <Tip>
  956. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
  957. </Tip>
  958. mirror (`str`, *optional*):
  959. Mirror source to accelerate downloads in China. If you are from China and have an accessibility
  960. problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
  961. Please refer to the mirror site for more information.
  962. <Tip>
  963. Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
  964. use this method in a firewalled environment.
  965. </Tip>
  966. Examples:
  967. ```python
  968. >>> from transformers import Wav2Vec2ForCTC, AutoProcessor
  969. >>> ckpt = "facebook/mms-1b-all"
  970. >>> processor = AutoProcessor.from_pretrained(ckpt)
  971. >>> model = Wav2Vec2ForCTC.from_pretrained(ckpt, target_lang="eng")
  972. >>> # set specific language
  973. >>> processor.tokenizer.set_target_lang("spa")
  974. >>> model.load_adapter("spa")
  975. ```
  976. """
  977. if self.config.adapter_attn_dim is None:
  978. raise ValueError(f"Cannot load_adapter for {target_lang} if `config.adapter_attn_dim` is not defined.")
  979. if target_lang == self.target_lang and not force_load:
  980. logger.warning(f"Adapter weights are already set to {target_lang}.")
  981. return
  982. cache_dir = kwargs.pop("cache_dir", None)
  983. force_download = kwargs.pop("force_download", False)
  984. resume_download = kwargs.pop("resume_download", None)
  985. proxies = kwargs.pop("proxies", None)
  986. local_files_only = kwargs.pop("local_files_only", False)
  987. token = kwargs.pop("token", None)
  988. use_auth_token = kwargs.pop("use_auth_token", None)
  989. revision = kwargs.pop("revision", None)
  990. use_safetensors = kwargs.pop("use_safetensors", None)
  991. if use_auth_token is not None:
  992. warnings.warn(
  993. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  994. FutureWarning,
  995. )
  996. if token is not None:
  997. raise ValueError(
  998. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  999. )
  1000. token = use_auth_token
  1001. model_path_or_id = self.config._name_or_path
  1002. state_dict = None
  1003. # 1. Let's first try loading a safetensors adapter weight
  1004. if use_safetensors is not False:
  1005. filepath = WAV2VEC2_ADAPTER_SAFE_FILE.format(target_lang)
  1006. try:
  1007. weight_path = cached_file(
  1008. model_path_or_id,
  1009. filename=filepath,
  1010. force_download=force_download,
  1011. resume_download=resume_download,
  1012. proxies=proxies,
  1013. local_files_only=local_files_only,
  1014. token=token,
  1015. revision=revision,
  1016. cache_dir=cache_dir,
  1017. )
  1018. state_dict = safe_load_file(weight_path)
  1019. except OSError:
  1020. if use_safetensors:
  1021. # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
  1022. # to the original exception.
  1023. raise
  1024. except Exception:
  1025. # For any other exception, we throw a generic error.
  1026. if use_safetensors:
  1027. raise OSError(
  1028. f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
  1029. " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  1030. f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
  1031. f" directory containing a file named {filepath}."
  1032. )
  1033. # 2. If this didn't work let's try loading a PyTorch adapter weight
  1034. if state_dict is None:
  1035. filepath = WAV2VEC2_ADAPTER_PT_FILE.format(target_lang)
  1036. try:
  1037. weight_path = cached_file(
  1038. model_path_or_id,
  1039. filename=filepath,
  1040. force_download=force_download,
  1041. resume_download=resume_download,
  1042. proxies=proxies,
  1043. local_files_only=local_files_only,
  1044. token=token,
  1045. revision=revision,
  1046. cache_dir=cache_dir,
  1047. )
  1048. check_torch_load_is_safe()
  1049. state_dict = torch.load(
  1050. weight_path,
  1051. map_location="cpu",
  1052. weights_only=True,
  1053. )
  1054. except OSError:
  1055. # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
  1056. # to the original exception.
  1057. raise
  1058. except ValueError:
  1059. raise
  1060. except Exception:
  1061. # For any other exception, we throw a generic error.
  1062. raise OSError(
  1063. f"Can't load the model for '{model_path_or_id}'. If you were trying to load it"
  1064. " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  1065. f" same name. Otherwise, make sure '{model_path_or_id}' is the correct path to a"
  1066. f" directory containing a file named {filepath}."
  1067. )
  1068. adapter_weights = self._get_adapters()
  1069. unexpected_keys = set(state_dict.keys()) - set(adapter_weights.keys())
  1070. missing_keys = set(adapter_weights.keys()) - set(state_dict.keys())
  1071. if len(unexpected_keys) > 0:
  1072. raise ValueError(f"The adapter weights {weight_path} has unexpected keys: {', '.join(unexpected_keys)}.")
  1073. elif len(missing_keys) > 0:
  1074. raise ValueError(f"The adapter weights {weight_path} has missing keys: {', '.join(missing_keys)}.")
  1075. # make sure now vocab size is correct
  1076. target_vocab_size = state_dict["lm_head.weight"].shape[0]
  1077. if target_vocab_size != self.config.vocab_size:
  1078. self.lm_head = nn.Linear(
  1079. self.config.output_hidden_size, target_vocab_size, device=self.device, dtype=self.dtype
  1080. )
  1081. self.config.vocab_size = target_vocab_size
  1082. # make sure that adapter weights are put in exactly the same precision and device placement and overwritten adapter weights
  1083. state_dict = {k: v.to(adapter_weights[k]) for k, v in state_dict.items()}
  1084. self.load_state_dict(state_dict, strict=False)
  1085. # set target language correctly
  1086. self.target_lang = target_lang
  1087. @auto_docstring
  1088. class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
  1089. def __init__(self, config: Wav2Vec2Config):
  1090. super().__init__(config)
  1091. self.config = config
  1092. self.feature_extractor = Wav2Vec2FeatureEncoder(config)
  1093. self.feature_projection = Wav2Vec2FeatureProjection(config)
  1094. # model only needs masking vector if mask prob is > 0.0
  1095. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  1096. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  1097. if config.do_stable_layer_norm:
  1098. self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
  1099. else:
  1100. self.encoder = Wav2Vec2Encoder(config)
  1101. self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None
  1102. # Initialize weights and apply final processing
  1103. self.post_init()
  1104. def freeze_feature_extractor(self):
  1105. """
  1106. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  1107. not be updated during training.
  1108. """
  1109. warnings.warn(
  1110. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1111. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1112. FutureWarning,
  1113. )
  1114. self.freeze_feature_encoder()
    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameters
        will not be updated during training.

        The work is delegated to `self.feature_extractor._freeze_parameters()`.
        """
        self.feature_extractor._freeze_parameters()
  1121. def _mask_hidden_states(
  1122. self,
  1123. hidden_states: torch.FloatTensor,
  1124. mask_time_indices: Optional[torch.FloatTensor] = None,
  1125. attention_mask: Optional[torch.LongTensor] = None,
  1126. ):
  1127. """
  1128. Masks extracted features along time axis and/or along feature axis according to
  1129. [SpecAugment](https://huggingface.co/papers/1904.08779).
  1130. """
  1131. # `config.apply_spec_augment` can set masking to False
  1132. if not getattr(self.config, "apply_spec_augment", True):
  1133. return hidden_states
  1134. # generate indices & apply SpecAugment along time axis
  1135. batch_size, sequence_length, hidden_size = hidden_states.size()
  1136. if mask_time_indices is not None:
  1137. # apply SpecAugment along time axis with given mask_time_indices
  1138. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  1139. elif self.config.mask_time_prob > 0 and self.training:
  1140. mask_time_indices = _compute_mask_indices(
  1141. (batch_size, sequence_length),
  1142. mask_prob=self.config.mask_time_prob,
  1143. mask_length=self.config.mask_time_length,
  1144. attention_mask=attention_mask,
  1145. min_masks=self.config.mask_time_min_masks,
  1146. )
  1147. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  1148. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  1149. if self.config.mask_feature_prob > 0 and self.training:
  1150. # generate indices & apply SpecAugment along feature axis
  1151. mask_feature_indices = _compute_mask_indices(
  1152. (batch_size, hidden_size),
  1153. mask_prob=self.config.mask_feature_prob,
  1154. mask_length=self.config.mask_feature_length,
  1155. min_masks=self.config.mask_feature_min_masks,
  1156. )
  1157. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  1158. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  1159. hidden_states[mask_feature_indices] = 0
  1160. return hidden_states
  1161. @auto_docstring
  1162. def forward(
  1163. self,
  1164. input_values: Optional[torch.Tensor],
  1165. attention_mask: Optional[torch.Tensor] = None,
  1166. mask_time_indices: Optional[torch.FloatTensor] = None,
  1167. output_attentions: Optional[bool] = None,
  1168. output_hidden_states: Optional[bool] = None,
  1169. return_dict: Optional[bool] = None,
  1170. ) -> Union[tuple, Wav2Vec2BaseModelOutput]:
  1171. r"""
  1172. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1173. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  1174. masked extracted features in *config.proj_codevector_dim* space.
  1175. """
  1176. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1177. output_hidden_states = (
  1178. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1179. )
  1180. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1181. extract_features = self.feature_extractor(input_values)
  1182. extract_features = extract_features.transpose(1, 2)
  1183. if attention_mask is not None:
  1184. # compute reduced attention_mask corresponding to feature vectors
  1185. attention_mask = self._get_feature_vector_attention_mask(
  1186. extract_features.shape[1], attention_mask, add_adapter=False
  1187. )
  1188. hidden_states, extract_features = self.feature_projection(extract_features)
  1189. hidden_states = self._mask_hidden_states(
  1190. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  1191. )
  1192. encoder_outputs = self.encoder(
  1193. hidden_states,
  1194. attention_mask=attention_mask,
  1195. output_attentions=output_attentions,
  1196. output_hidden_states=output_hidden_states,
  1197. return_dict=return_dict,
  1198. )
  1199. hidden_states = encoder_outputs[0]
  1200. if self.adapter is not None:
  1201. hidden_states = self.adapter(hidden_states)
  1202. if not return_dict:
  1203. return (hidden_states, extract_features) + encoder_outputs[1:]
  1204. return Wav2Vec2BaseModelOutput(
  1205. last_hidden_state=hidden_states,
  1206. extract_features=extract_features,
  1207. hidden_states=encoder_outputs.hidden_states,
  1208. attentions=encoder_outputs.attentions,
  1209. )
  1210. @auto_docstring(
  1211. custom_intro="""
  1212. Wav2Vec2 Model with a quantizer and `VQ` head on top.
  1213. """
  1214. )
  1215. class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
  1216. def __init__(self, config: Wav2Vec2Config):
  1217. super().__init__(config)
  1218. self.wav2vec2 = Wav2Vec2Model(config)
  1219. self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
  1220. self.quantizer = Wav2Vec2GumbelVectorQuantizer(config)
  1221. self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim)
  1222. self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
  1223. # Initialize weights and apply final processing
  1224. self.post_init()
  1225. def set_gumbel_temperature(self, temperature: int):
  1226. """
  1227. Set the Gumbel softmax temperature to a given value. Only necessary for training
  1228. """
  1229. self.quantizer.temperature = temperature
  1230. def freeze_feature_extractor(self):
  1231. """
  1232. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  1233. not be updated during training.
  1234. """
  1235. warnings.warn(
  1236. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1237. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1238. FutureWarning,
  1239. )
  1240. self.freeze_feature_encoder()
  1241. def freeze_feature_encoder(self):
  1242. """
  1243. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1244. not be updated during training.
  1245. """
  1246. self.wav2vec2.feature_extractor._freeze_parameters()
  1247. @staticmethod
  1248. def compute_contrastive_logits(
  1249. target_features: torch.FloatTensor,
  1250. negative_features: torch.FloatTensor,
  1251. predicted_features: torch.FloatTensor,
  1252. temperature: int = 0.1,
  1253. ):
  1254. """
  1255. Compute logits for contrastive loss based using cosine similarity as the distance measure between
  1256. `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
  1257. """
  1258. target_features = torch.cat([target_features, negative_features], dim=0)
  1259. logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
  1260. target_features
  1261. )
  1262. # apply temperature
  1263. logits = logits / temperature
  1264. return logits
  1265. @auto_docstring
  1266. def forward(
  1267. self,
  1268. input_values: Optional[torch.Tensor],
  1269. attention_mask: Optional[torch.Tensor] = None,
  1270. mask_time_indices: Optional[torch.BoolTensor] = None,
  1271. sampled_negative_indices: Optional[torch.BoolTensor] = None,
  1272. output_attentions: Optional[bool] = None,
  1273. output_hidden_states: Optional[bool] = None,
  1274. return_dict: Optional[bool] = None,
  1275. ) -> Union[tuple, Wav2Vec2ForPreTrainingOutput]:
  1276. r"""
  1277. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1278. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  1279. masked extracted features in *config.proj_codevector_dim* space.
  1280. sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*):
  1281. Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss.
  1282. Required input for pre-training.
  1283. Example:
  1284. ```python
  1285. >>> import torch
  1286. >>> from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining
  1287. >>> from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
  1288. >>> from datasets import load_dataset
  1289. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
  1290. >>> model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
  1291. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  1292. >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1
  1293. >>> # compute masked indices
  1294. >>> batch_size, raw_sequence_length = input_values.shape
  1295. >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()
  1296. >>> mask_time_indices = _compute_mask_indices(
  1297. ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2
  1298. ... )
  1299. >>> sampled_negative_indices = _sample_negative_indices(
  1300. ... features_shape=(batch_size, sequence_length),
  1301. ... num_negatives=model.config.num_negatives,
  1302. ... mask_time_indices=mask_time_indices,
  1303. ... )
  1304. >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long)
  1305. >>> sampled_negative_indices = torch.tensor(
  1306. ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long
  1307. ... )
  1308. >>> with torch.no_grad():
  1309. ... outputs = model(input_values, mask_time_indices=mask_time_indices)
  1310. >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states)
  1311. >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1)
  1312. >>> # show that cosine similarity is much higher than random
  1313. >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5
  1314. tensor(True)
  1315. >>> # for contrastive loss training model should be put into train mode
  1316. >>> model = model.train()
  1317. >>> loss = model(
  1318. ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
  1319. ... ).loss
  1320. ```"""
  1321. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1322. if mask_time_indices is not None:
  1323. mask_time_indices = mask_time_indices.to(torch.bool)
  1324. outputs = self.wav2vec2(
  1325. input_values,
  1326. attention_mask=attention_mask,
  1327. output_attentions=output_attentions,
  1328. output_hidden_states=output_hidden_states,
  1329. mask_time_indices=mask_time_indices,
  1330. return_dict=return_dict,
  1331. )
  1332. # 1. project all transformed features (including masked) to final vq dim
  1333. transformer_features = self.project_hid(outputs[0])
  1334. # 2. quantize all (unmasked) extracted features and project to final vq dim
  1335. extract_features = self.dropout_features(outputs[1])
  1336. if attention_mask is not None:
  1337. # compute reduced attention_mask corresponding to feature vectors
  1338. attention_mask = self._get_feature_vector_attention_mask(
  1339. extract_features.shape[1], attention_mask, add_adapter=False
  1340. )
  1341. quantized_features, codevector_perplexity = self.quantizer(
  1342. extract_features, mask_time_indices=mask_time_indices
  1343. )
  1344. quantized_features = quantized_features.to(self.project_q.weight.dtype)
  1345. quantized_features = self.project_q(quantized_features)
  1346. loss = contrastive_loss = diversity_loss = None
  1347. if sampled_negative_indices is not None:
  1348. batch_size, sequence_length, hidden_size = quantized_features.shape
  1349. # for training, we sample negatives
  1350. # 3. sample K negatives (distractors) quantized states for contrastive loss
  1351. # if attention_mask is passed, make sure that padded feature vectors cannot be sampled
  1352. # sample negative quantized vectors BTC => (BxT)C
  1353. negative_quantized_features = quantized_features.view(-1, hidden_size)[
  1354. sampled_negative_indices.long().view(-1)
  1355. ]
  1356. negative_quantized_features = negative_quantized_features.view(
  1357. batch_size, sequence_length, -1, hidden_size
  1358. ).permute(2, 0, 1, 3)
  1359. # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
  1360. # of equation (3) in https://huggingface.co/papers/2006.11477
  1361. logits = self.compute_contrastive_logits(
  1362. quantized_features[None, :],
  1363. negative_quantized_features,
  1364. transformer_features,
  1365. self.config.contrastive_logits_temperature,
  1366. )
  1367. # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
  1368. # its cosine similarity will be masked
  1369. neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
  1370. if neg_is_pos.any():
  1371. logits[1:][neg_is_pos] = float("-inf")
  1372. # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) =
  1373. # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
  1374. logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
  1375. target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
  1376. contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
  1377. # 7. compute diversity loss: \mathbf{L}_d
  1378. num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups
  1379. diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
  1380. # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d
  1381. loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss
  1382. if not return_dict:
  1383. if loss is not None:
  1384. return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  1385. return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  1386. return Wav2Vec2ForPreTrainingOutput(
  1387. loss=loss,
  1388. projected_states=transformer_features,
  1389. projected_quantized_states=quantized_features,
  1390. codevector_perplexity=codevector_perplexity,
  1391. hidden_states=outputs.hidden_states,
  1392. attentions=outputs.attentions,
  1393. contrastive_loss=contrastive_loss,
  1394. diversity_loss=diversity_loss,
  1395. )
  1396. @auto_docstring
  1397. class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
  1398. def __init__(self, config):
  1399. super().__init__(config)
  1400. warnings.warn(
  1401. "The class `Wav2Vec2ForMaskedLM` is deprecated. Please use `Wav2Vec2ForCTC` instead.", FutureWarning
  1402. )
  1403. self.wav2vec2 = Wav2Vec2Model(config)
  1404. self.dropout = nn.Dropout(config.final_dropout)
  1405. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
  1406. # Initialize weights and apply final processing
  1407. self.post_init()
  1408. @auto_docstring
  1409. def forward(
  1410. self,
  1411. input_values: torch.FloatTensor,
  1412. attention_mask: Optional[torch.LongTensor] = None,
  1413. output_attentions: Optional[bool] = None,
  1414. output_hidden_states: Optional[bool] = None,
  1415. return_dict: Optional[bool] = None,
  1416. labels: Optional[torch.Tensor] = None,
  1417. ) -> Union[tuple, MaskedLMOutput]:
  1418. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1419. outputs = self.wav2vec2(
  1420. input_values,
  1421. output_attentions=output_attentions,
  1422. output_hidden_states=output_hidden_states,
  1423. return_dict=return_dict,
  1424. )
  1425. hidden_states = outputs[0]
  1426. hidden_states = self.dropout(hidden_states)
  1427. logits = self.lm_head(hidden_states)
  1428. if not return_dict:
  1429. output = (logits,) + outputs[2:]
  1430. return output
  1431. return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
  1432. @auto_docstring(
  1433. custom_intro="""
  1434. Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
  1435. """
  1436. )
  1437. class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
  1438. def __init__(self, config, target_lang: Optional[str] = None):
  1439. r"""
  1440. target_lang (`str`, *optional*):
  1441. Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
  1442. adapter.<lang>.bin. Only relevant when using an instance of [`Wav2Vec2ForCTC`] with adapters. Uses 'eng' by
  1443. default.
  1444. """
  1445. super().__init__(config)
  1446. self.wav2vec2 = Wav2Vec2Model(config)
  1447. self.dropout = nn.Dropout(config.final_dropout)
  1448. self.target_lang = target_lang
  1449. if config.vocab_size is None:
  1450. raise ValueError(
  1451. f"You are trying to instantiate {self.__class__} with a configuration that "
  1452. "does not define the vocabulary size of the language model head. Please "
  1453. "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  1454. "or define `vocab_size` of your model's configuration."
  1455. )
  1456. output_hidden_size = (
  1457. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  1458. )
  1459. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  1460. # Initialize weights and apply final processing
  1461. self.post_init()
  1462. def tie_weights(self):
  1463. """
  1464. This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
  1465. passing `target_lang=...` to `from_pretrained(...)`.
  1466. This method is **not** supposed to be called by the user and is prone to be changed in the future.
  1467. """
  1468. # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
  1469. # correctly load adapter layers for Wav2Vec2 so that we do not have to introduce a new API to
  1470. # [`PreTrainedModel`]. While slightly hacky, Wav2Vec2 never has to tie input and output embeddings, so that it is
  1471. # ok to repurpose this function here.
  1472. target_lang = self.target_lang
  1473. if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
  1474. raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
  1475. elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
  1476. logger.info("By default `target_lang` is set to 'eng'.")
  1477. elif target_lang is not None:
  1478. self.load_adapter(target_lang, force_load=True)
  1479. def freeze_feature_extractor(self):
  1480. """
  1481. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1482. not be updated during training.
  1483. """
  1484. warnings.warn(
  1485. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1486. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1487. FutureWarning,
  1488. )
  1489. self.freeze_feature_encoder()
  1490. def freeze_feature_encoder(self):
  1491. """
  1492. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1493. not be updated during training.
  1494. """
  1495. self.wav2vec2.feature_extractor._freeze_parameters()
  1496. def freeze_base_model(self):
  1497. """
  1498. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1499. be updated during training. Only the classification head will be updated.
  1500. """
  1501. for param in self.wav2vec2.parameters():
  1502. param.requires_grad = False
  1503. @auto_docstring
  1504. def forward(
  1505. self,
  1506. input_values: Optional[torch.Tensor],
  1507. attention_mask: Optional[torch.Tensor] = None,
  1508. output_attentions: Optional[bool] = None,
  1509. output_hidden_states: Optional[bool] = None,
  1510. return_dict: Optional[bool] = None,
  1511. labels: Optional[torch.Tensor] = None,
  1512. ) -> Union[tuple, CausalLMOutput]:
  1513. r"""
  1514. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  1515. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  1516. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  1517. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  1518. config.vocab_size - 1]`.
  1519. """
  1520. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1521. if labels is not None and labels.max() >= self.config.vocab_size:
  1522. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  1523. outputs = self.wav2vec2(
  1524. input_values,
  1525. attention_mask=attention_mask,
  1526. output_attentions=output_attentions,
  1527. output_hidden_states=output_hidden_states,
  1528. return_dict=return_dict,
  1529. )
  1530. hidden_states = outputs[0]
  1531. hidden_states = self.dropout(hidden_states)
  1532. logits = self.lm_head(hidden_states)
  1533. loss = None
  1534. if labels is not None:
  1535. # retrieve loss input_lengths from attention_mask
  1536. attention_mask = (
  1537. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  1538. )
  1539. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  1540. # assuming that padded tokens are filled with -100
  1541. # when not being attended to
  1542. labels_mask = labels >= 0
  1543. target_lengths = labels_mask.sum(-1)
  1544. flattened_targets = labels.masked_select(labels_mask)
  1545. # ctc_loss doesn't support fp16
  1546. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  1547. with torch.backends.cudnn.flags(enabled=False):
  1548. loss = nn.functional.ctc_loss(
  1549. log_probs,
  1550. flattened_targets,
  1551. input_lengths,
  1552. target_lengths,
  1553. blank=self.config.pad_token_id,
  1554. reduction=self.config.ctc_loss_reduction,
  1555. zero_infinity=self.config.ctc_zero_infinity,
  1556. )
  1557. if not return_dict:
  1558. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1559. return ((loss,) + output) if loss is not None else output
  1560. return CausalLMOutput(
  1561. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  1562. )
  1563. @auto_docstring(
  1564. custom_intro="""
  1565. Wav2Vec2 Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
  1566. SUPERB Keyword Spotting.
  1567. """
  1568. )
  1569. class Wav2Vec2ForSequenceClassification(Wav2Vec2PreTrainedModel):
  1570. def __init__(self, config):
  1571. super().__init__(config)
  1572. if hasattr(config, "add_adapter") and config.add_adapter:
  1573. raise ValueError(
  1574. "Sequence classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
  1575. )
  1576. self.wav2vec2 = Wav2Vec2Model(config)
  1577. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1578. if config.use_weighted_layer_sum:
  1579. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1580. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  1581. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  1582. # Initialize weights and apply final processing
  1583. self.post_init()
  1584. def freeze_feature_extractor(self):
  1585. """
  1586. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  1587. not be updated during training.
  1588. """
  1589. warnings.warn(
  1590. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1591. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1592. FutureWarning,
  1593. )
  1594. self.freeze_feature_encoder()
  1595. def freeze_feature_encoder(self):
  1596. """
  1597. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1598. not be updated during training.
  1599. """
  1600. self.wav2vec2.feature_extractor._freeze_parameters()
  1601. def freeze_base_model(self):
  1602. """
  1603. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1604. be updated during training. Only the classification head will be updated.
  1605. """
  1606. for param in self.wav2vec2.parameters():
  1607. param.requires_grad = False
  1608. @auto_docstring
  1609. def forward(
  1610. self,
  1611. input_values: Optional[torch.Tensor],
  1612. attention_mask: Optional[torch.Tensor] = None,
  1613. output_attentions: Optional[bool] = None,
  1614. output_hidden_states: Optional[bool] = None,
  1615. return_dict: Optional[bool] = None,
  1616. labels: Optional[torch.Tensor] = None,
  1617. ) -> Union[tuple, SequenceClassifierOutput]:
  1618. r"""
  1619. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1620. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1621. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1622. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1623. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1624. into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
  1625. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1626. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1627. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1628. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1629. """
  1630. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1631. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1632. outputs = self.wav2vec2(
  1633. input_values,
  1634. attention_mask=attention_mask,
  1635. output_attentions=output_attentions,
  1636. output_hidden_states=output_hidden_states,
  1637. return_dict=return_dict,
  1638. )
  1639. if self.config.use_weighted_layer_sum:
  1640. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1641. hidden_states = torch.stack(hidden_states, dim=1)
  1642. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1643. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1644. else:
  1645. hidden_states = outputs[0]
  1646. hidden_states = self.projector(hidden_states)
  1647. if attention_mask is None:
  1648. pooled_output = hidden_states.mean(dim=1)
  1649. else:
  1650. padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  1651. expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  1652. hidden_states[~expand_padding_mask] = 0.0
  1653. pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
  1654. logits = self.classifier(pooled_output)
  1655. loss = None
  1656. if labels is not None:
  1657. loss_fct = CrossEntropyLoss()
  1658. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1659. if not return_dict:
  1660. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1661. return ((loss,) + output) if loss is not None else output
  1662. return SequenceClassifierOutput(
  1663. loss=loss,
  1664. logits=logits,
  1665. hidden_states=outputs.hidden_states,
  1666. attentions=outputs.attentions,
  1667. )
  1668. @auto_docstring
  1669. class Wav2Vec2ForAudioFrameClassification(Wav2Vec2PreTrainedModel):
  1670. def __init__(self, config):
  1671. super().__init__(config)
  1672. if hasattr(config, "add_adapter") and config.add_adapter:
  1673. raise ValueError(
  1674. "Audio frame classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
  1675. )
  1676. self.wav2vec2 = Wav2Vec2Model(config)
  1677. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1678. if config.use_weighted_layer_sum:
  1679. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1680. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1681. self.num_labels = config.num_labels
  1682. self.init_weights()
  1683. def freeze_feature_extractor(self):
  1684. """
  1685. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1686. not be updated during training.
  1687. """
  1688. warnings.warn(
  1689. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1690. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1691. FutureWarning,
  1692. )
  1693. self.freeze_feature_encoder()
  1694. def freeze_feature_encoder(self):
  1695. """
  1696. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1697. not be updated during training.
  1698. """
  1699. self.wav2vec2.feature_extractor._freeze_parameters()
  1700. def freeze_base_model(self):
  1701. """
  1702. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1703. be updated during training. Only the classification head will be updated.
  1704. """
  1705. for param in self.wav2vec2.parameters():
  1706. param.requires_grad = False
  1707. @auto_docstring
  1708. def forward(
  1709. self,
  1710. input_values: Optional[torch.Tensor],
  1711. attention_mask: Optional[torch.Tensor] = None,
  1712. labels: Optional[torch.Tensor] = None,
  1713. output_attentions: Optional[bool] = None,
  1714. output_hidden_states: Optional[bool] = None,
  1715. return_dict: Optional[bool] = None,
  1716. ) -> Union[tuple, TokenClassifierOutput]:
  1717. r"""
  1718. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1719. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1720. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1721. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1722. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1723. into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
  1724. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1725. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1726. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1727. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1728. """
  1729. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1730. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1731. outputs = self.wav2vec2(
  1732. input_values,
  1733. attention_mask=attention_mask,
  1734. output_attentions=output_attentions,
  1735. output_hidden_states=output_hidden_states,
  1736. return_dict=return_dict,
  1737. )
  1738. if self.config.use_weighted_layer_sum:
  1739. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1740. hidden_states = torch.stack(hidden_states, dim=1)
  1741. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1742. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1743. else:
  1744. hidden_states = outputs[0]
  1745. logits = self.classifier(hidden_states)
  1746. loss = None
  1747. if labels is not None:
  1748. loss_fct = CrossEntropyLoss()
  1749. loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
  1750. if not return_dict:
  1751. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1752. return output
  1753. return TokenClassifierOutput(
  1754. loss=loss,
  1755. logits=logits,
  1756. hidden_states=outputs.hidden_states,
  1757. attentions=outputs.attentions,
  1758. )
  1759. class AMSoftmaxLoss(nn.Module):
  1760. def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
  1761. super().__init__()
  1762. self.scale = scale
  1763. self.margin = margin
  1764. self.num_labels = num_labels
  1765. self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
  1766. self.loss = nn.CrossEntropyLoss()
  1767. def forward(self, hidden_states, labels):
  1768. labels = labels.flatten()
  1769. weight = nn.functional.normalize(self.weight, dim=0)
  1770. hidden_states = nn.functional.normalize(hidden_states, dim=1)
  1771. cos_theta = torch.mm(hidden_states, weight)
  1772. psi = cos_theta - self.margin
  1773. onehot = nn.functional.one_hot(labels, self.num_labels)
  1774. logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
  1775. loss = self.loss(logits, labels)
  1776. return loss
  1777. class TDNNLayer(nn.Module):
  1778. def __init__(self, config, layer_id=0):
  1779. super().__init__()
  1780. self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
  1781. self.out_conv_dim = config.tdnn_dim[layer_id]
  1782. self.kernel_size = config.tdnn_kernel[layer_id]
  1783. self.dilation = config.tdnn_dilation[layer_id]
  1784. self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
  1785. self.activation = nn.ReLU()
  1786. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1787. if is_peft_available():
  1788. from peft.tuners.lora import LoraLayer
  1789. if is_peft_available():
  1790. if isinstance(self.kernel, LoraLayer):
  1791. warnings.warn(
  1792. "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
  1793. "You should exclude TDNNLayer from LoRA's target modules.",
  1794. )
  1795. # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
  1796. hidden_states = hidden_states.transpose(1, 2)
  1797. weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
  1798. hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
  1799. hidden_states = hidden_states.transpose(1, 2)
  1800. hidden_states = self.activation(hidden_states)
  1801. return hidden_states
  1802. @auto_docstring(
  1803. custom_intro="""
  1804. Wav2Vec2 Model with an XVector feature extraction head on top for tasks like Speaker Verification.
  1805. """
  1806. )
  1807. class Wav2Vec2ForXVector(Wav2Vec2PreTrainedModel):
  1808. def __init__(self, config):
  1809. super().__init__(config)
  1810. self.wav2vec2 = Wav2Vec2Model(config)
  1811. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1812. if config.use_weighted_layer_sum:
  1813. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1814. self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
  1815. tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
  1816. self.tdnn = nn.ModuleList(tdnn_layers)
  1817. self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
  1818. self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
  1819. self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
  1820. self.init_weights()
  1821. def freeze_feature_extractor(self):
  1822. """
  1823. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1824. not be updated during training.
  1825. """
  1826. warnings.warn(
  1827. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1828. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1829. FutureWarning,
  1830. )
  1831. self.freeze_feature_encoder()
  1832. def freeze_feature_encoder(self):
  1833. """
  1834. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1835. not be updated during training.
  1836. """
  1837. self.wav2vec2.feature_extractor._freeze_parameters()
  1838. def freeze_base_model(self):
  1839. """
  1840. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1841. be updated during training. Only the classification head will be updated.
  1842. """
  1843. for param in self.wav2vec2.parameters():
  1844. param.requires_grad = False
  1845. def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
  1846. """
  1847. Computes the output length of the TDNN layers
  1848. """
  1849. def _conv_out_length(input_length, kernel_size, stride):
  1850. # 1D convolutional layer output length formula taken
  1851. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  1852. return (input_length - kernel_size) // stride + 1
  1853. for kernel_size in self.config.tdnn_kernel:
  1854. input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
  1855. return input_lengths
  1856. @auto_docstring
  1857. def forward(
  1858. self,
  1859. input_values: Optional[torch.Tensor],
  1860. attention_mask: Optional[torch.Tensor] = None,
  1861. output_attentions: Optional[bool] = None,
  1862. output_hidden_states: Optional[bool] = None,
  1863. return_dict: Optional[bool] = None,
  1864. labels: Optional[torch.Tensor] = None,
  1865. ) -> Union[tuple, XVectorOutput]:
  1866. r"""
  1867. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1868. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1869. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1870. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1871. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1872. into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details.
  1873. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1874. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1875. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1876. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1877. """
  1878. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1879. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1880. outputs = self.wav2vec2(
  1881. input_values,
  1882. attention_mask=attention_mask,
  1883. output_attentions=output_attentions,
  1884. output_hidden_states=output_hidden_states,
  1885. return_dict=return_dict,
  1886. )
  1887. if self.config.use_weighted_layer_sum:
  1888. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1889. hidden_states = torch.stack(hidden_states, dim=1)
  1890. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1891. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1892. else:
  1893. hidden_states = outputs[0]
  1894. hidden_states = self.projector(hidden_states)
  1895. for tdnn_layer in self.tdnn:
  1896. hidden_states = tdnn_layer(hidden_states)
  1897. # Statistic Pooling
  1898. if attention_mask is None:
  1899. mean_features = hidden_states.mean(dim=1)
  1900. std_features = hidden_states.std(dim=1)
  1901. else:
  1902. feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
  1903. tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
  1904. mean_features = []
  1905. std_features = []
  1906. for i, length in enumerate(tdnn_output_lengths):
  1907. mean_features.append(hidden_states[i, :length].mean(dim=0))
  1908. std_features.append(hidden_states[i, :length].std(dim=0))
  1909. mean_features = torch.stack(mean_features)
  1910. std_features = torch.stack(std_features)
  1911. statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
  1912. output_embeddings = self.feature_extractor(statistic_pooling)
  1913. logits = self.classifier(output_embeddings)
  1914. loss = None
  1915. if labels is not None:
  1916. loss = self.objective(logits, labels)
  1917. if not return_dict:
  1918. output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
  1919. return ((loss,) + output) if loss is not None else output
  1920. return XVectorOutput(
  1921. loss=loss,
  1922. logits=logits,
  1923. embeddings=output_embeddings,
  1924. hidden_states=outputs.hidden_states,
  1925. attentions=outputs.attentions,
  1926. )
# Names exported by a star-import of this module, i.e. its declared public API.
__all__ = [
    "Wav2Vec2ForAudioFrameClassification",
    "Wav2Vec2ForCTC",
    "Wav2Vec2ForMaskedLM",
    "Wav2Vec2ForPreTraining",
    "Wav2Vec2ForSequenceClassification",
    "Wav2Vec2ForXVector",
    "Wav2Vec2Model",
    "Wav2Vec2PreTrainedModel",
]