modeling_wavlm.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/wavlm/modular_wavlm.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_wavlm.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. import math
  8. import warnings
  9. from typing import Optional, Union
  10. import numpy as np
  11. import torch
  12. import torch.nn as nn
  13. import torch.nn.functional as F
  14. from torch.nn import CrossEntropyLoss
  15. from ...activations import ACT2FN
  16. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  17. from ...integrations.fsdp import is_fsdp_managed_module
  18. from ...modeling_layers import GradientCheckpointingLayer
  19. from ...modeling_outputs import (
  20. BaseModelOutput,
  21. CausalLMOutput,
  22. SequenceClassifierOutput,
  23. TokenClassifierOutput,
  24. Wav2Vec2BaseModelOutput,
  25. XVectorOutput,
  26. )
  27. from ...modeling_utils import PreTrainedModel
  28. from ...utils import auto_docstring, is_peft_available, logging
  29. from .configuration_wavlm import WavLMConfig
  30. logger = logging.get_logger(__name__)
  31. class WavLMSamePadLayer(nn.Module):
  32. def __init__(self, num_conv_pos_embeddings):
  33. super().__init__()
  34. self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0
  35. def forward(self, hidden_states):
  36. if self.num_pad_remove > 0:
  37. hidden_states = hidden_states[:, :, : -self.num_pad_remove]
  38. return hidden_states
  39. class WavLMPositionalConvEmbedding(nn.Module):
  40. def __init__(self, config):
  41. super().__init__()
  42. self.conv = nn.Conv1d(
  43. config.hidden_size,
  44. config.hidden_size,
  45. kernel_size=config.num_conv_pos_embeddings,
  46. padding=config.num_conv_pos_embeddings // 2,
  47. groups=config.num_conv_pos_embedding_groups,
  48. )
  49. weight_norm = nn.utils.weight_norm
  50. if hasattr(nn.utils.parametrizations, "weight_norm"):
  51. weight_norm = nn.utils.parametrizations.weight_norm
  52. if is_deepspeed_zero3_enabled():
  53. import deepspeed
  54. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  55. self.conv = weight_norm(self.conv, name="weight", dim=2)
  56. if hasattr(self.conv, "parametrizations"):
  57. weight_g = self.conv.parametrizations.weight.original0
  58. weight_v = self.conv.parametrizations.weight.original1
  59. else:
  60. weight_g = self.conv.weight_g
  61. weight_v = self.conv.weight_v
  62. deepspeed.zero.register_external_parameter(self, weight_v)
  63. deepspeed.zero.register_external_parameter(self, weight_g)
  64. else:
  65. self.conv = weight_norm(self.conv, name="weight", dim=2)
  66. self.padding = WavLMSamePadLayer(config.num_conv_pos_embeddings)
  67. self.activation = ACT2FN[config.feat_extract_activation]
  68. def forward(self, hidden_states):
  69. hidden_states = hidden_states.transpose(1, 2)
  70. hidden_states = self.conv(hidden_states)
  71. hidden_states = self.padding(hidden_states)
  72. hidden_states = self.activation(hidden_states)
  73. hidden_states = hidden_states.transpose(1, 2)
  74. return hidden_states
  75. class WavLMFeatureProjection(nn.Module):
  76. def __init__(self, config):
  77. super().__init__()
  78. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  79. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  80. self.dropout = nn.Dropout(config.feat_proj_dropout)
  81. def forward(self, hidden_states):
  82. # non-projected hidden states are needed for quantization
  83. norm_hidden_states = self.layer_norm(hidden_states)
  84. hidden_states = self.projection(norm_hidden_states)
  85. hidden_states = self.dropout(hidden_states)
  86. return hidden_states, norm_hidden_states
  87. class WavLMAttention(nn.Module):
  88. """Multi-headed attention from 'Attention Is All You Need' paper"""
  89. def __init__(
  90. self,
  91. embed_dim: int,
  92. num_heads: int,
  93. dropout: float = 0.0,
  94. num_buckets: int = 320,
  95. max_distance: int = 800,
  96. has_relative_position_bias: bool = True,
  97. ):
  98. super().__init__()
  99. self.embed_dim = embed_dim
  100. self.num_heads = num_heads
  101. self.dropout = dropout
  102. self.head_dim = embed_dim // num_heads
  103. if (self.head_dim * num_heads) != self.embed_dim:
  104. raise ValueError(
  105. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  106. f" and `num_heads`: {num_heads})."
  107. )
  108. self.scaling = self.head_dim**-0.5
  109. self.k_proj = nn.Linear(embed_dim, embed_dim)
  110. self.v_proj = nn.Linear(embed_dim, embed_dim)
  111. self.q_proj = nn.Linear(embed_dim, embed_dim)
  112. self.out_proj = nn.Linear(embed_dim, embed_dim)
  113. self.num_buckets = num_buckets
  114. self.max_distance = max_distance
  115. self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
  116. self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
  117. if has_relative_position_bias:
  118. self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
  119. def forward(
  120. self,
  121. hidden_states: torch.Tensor,
  122. attention_mask: Optional[torch.Tensor] = None,
  123. position_bias: Optional[torch.Tensor] = None,
  124. output_attentions: bool = False,
  125. index=0,
  126. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  127. """Attention layer with relative attention"""
  128. bsz, tgt_len, _ = hidden_states.size()
  129. # first pass of attention layer creates position bias
  130. if position_bias is None:
  131. position_bias = self.compute_bias(tgt_len, tgt_len)
  132. position_bias = (
  133. position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, tgt_len)
  134. )
  135. # Compute relative position bias:
  136. # 1) get reshape hidden_states
  137. gated_hidden_states = hidden_states.view(hidden_states.shape[:-1] + (self.num_heads, -1))
  138. gated_hidden_states = gated_hidden_states.permute(0, 2, 1, 3)
  139. # 2) project hidden states
  140. relative_position_proj = self.gru_rel_pos_linear(gated_hidden_states)
  141. relative_position_proj = relative_position_proj.view(gated_hidden_states.shape[:-1] + (2, 4)).sum(-1)
  142. # 3) compute gate for position bias from projected hidden states
  143. gate_a, gate_b = torch.sigmoid(relative_position_proj).chunk(2, dim=-1)
  144. gate_output = gate_a * (gate_b * self.gru_rel_pos_const - 1.0) + 2.0
  145. # 4) apply gate to position bias to compute gated position_bias
  146. gated_position_bias = gate_output.view(bsz * self.num_heads, -1, 1) * position_bias
  147. gated_position_bias = gated_position_bias.view((-1, tgt_len, tgt_len))
  148. attn_output, attn_weights = self.torch_multi_head_self_attention(
  149. hidden_states, attention_mask, gated_position_bias, output_attentions
  150. )
  151. return attn_output, attn_weights, position_bias
  152. def torch_multi_head_self_attention(
  153. self,
  154. hidden_states: torch.FloatTensor,
  155. attention_mask: Union[torch.LongTensor, torch.BoolTensor],
  156. gated_position_bias: torch.FloatTensor,
  157. output_attentions: bool,
  158. ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
  159. """simple wrapper around torch's multi_head_attention_forward function"""
  160. # self-attention assumes q = k = v
  161. query = key = value = hidden_states.transpose(0, 1)
  162. key_padding_mask = attention_mask.ne(1) if attention_mask is not None else None
  163. # disable bias and add_zero_attn
  164. bias_k = bias_v = None
  165. add_zero_attn = False
  166. # PyTorch 1.3.0 has F.multi_head_attention_forward defined
  167. # so no problem with backwards compatibility
  168. attn_output, attn_weights = F.multi_head_attention_forward(
  169. query,
  170. key,
  171. value,
  172. self.embed_dim,
  173. self.num_heads,
  174. torch.empty([0]),
  175. torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
  176. bias_k,
  177. bias_v,
  178. add_zero_attn,
  179. self.dropout,
  180. self.out_proj.weight,
  181. self.out_proj.bias,
  182. self.training,
  183. key_padding_mask,
  184. output_attentions,
  185. gated_position_bias,
  186. use_separate_proj_weight=True,
  187. q_proj_weight=self.q_proj.weight,
  188. k_proj_weight=self.k_proj.weight,
  189. v_proj_weight=self.v_proj.weight,
  190. )
  191. # [Seq_Len, Batch Size, ...] -> [Batch Size, Seq_Len, ...]
  192. attn_output = attn_output.transpose(0, 1)
  193. if attn_weights is not None:
  194. # IMPORTANT: Attention weights are averaged weights
  195. # here which should not be the case. This is an open issue
  196. # on PyTorch: https://github.com/pytorch/pytorch/issues/32590
  197. attn_weights = attn_weights[:, None].broadcast_to(
  198. attn_weights.shape[:1] + (self.num_heads,) + attn_weights.shape[1:]
  199. )
  200. return attn_output, attn_weights
  201. def compute_bias(self, query_length: int, key_length: int) -> torch.FloatTensor:
  202. context_position = torch.arange(query_length, dtype=torch.long)[:, None]
  203. memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
  204. relative_position = memory_position - context_position
  205. relative_position_bucket = self._relative_positions_bucket(relative_position)
  206. relative_position_bucket = relative_position_bucket.to(self.rel_attn_embed.weight.device)
  207. values = self.rel_attn_embed(relative_position_bucket)
  208. values = values.permute([2, 0, 1])
  209. return values
  210. def _relative_positions_bucket(self, relative_positions: torch.FloatTensor) -> torch.FloatTensor:
  211. num_buckets = self.num_buckets // 2
  212. relative_buckets = (relative_positions > 0).to(torch.long) * num_buckets
  213. relative_positions = torch.abs(relative_positions)
  214. max_exact = num_buckets // 2
  215. is_small = relative_positions < max_exact
  216. relative_positions_if_large = torch.log(relative_positions.float() / max_exact)
  217. relative_positions_if_large = relative_positions_if_large / math.log(self.max_distance / max_exact)
  218. relative_positions_if_large = relative_positions_if_large * (num_buckets - max_exact)
  219. relative_position_if_large = (max_exact + relative_positions_if_large).to(torch.long)
  220. relative_position_if_large = torch.min(
  221. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  222. )
  223. relative_buckets += torch.where(is_small, relative_positions, relative_position_if_large)
  224. return relative_buckets
  225. class WavLMFeedForward(nn.Module):
  226. def __init__(self, config):
  227. super().__init__()
  228. self.intermediate_dropout = nn.Dropout(config.activation_dropout)
  229. self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
  230. if isinstance(config.hidden_act, str):
  231. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  232. else:
  233. self.intermediate_act_fn = config.hidden_act
  234. self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
  235. self.output_dropout = nn.Dropout(config.hidden_dropout)
  236. def forward(self, hidden_states):
  237. hidden_states = self.intermediate_dense(hidden_states)
  238. hidden_states = self.intermediate_act_fn(hidden_states)
  239. hidden_states = self.intermediate_dropout(hidden_states)
  240. hidden_states = self.output_dense(hidden_states)
  241. hidden_states = self.output_dropout(hidden_states)
  242. return hidden_states
  243. class WavLMEncoderLayer(GradientCheckpointingLayer):
  244. def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
  245. super().__init__()
  246. self.attention = WavLMAttention(
  247. embed_dim=config.hidden_size,
  248. num_heads=config.num_attention_heads,
  249. dropout=config.attention_dropout,
  250. num_buckets=config.num_buckets,
  251. max_distance=config.max_bucket_distance,
  252. has_relative_position_bias=has_relative_position_bias,
  253. )
  254. self.dropout = nn.Dropout(config.hidden_dropout)
  255. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  256. self.feed_forward = WavLMFeedForward(config)
  257. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  258. def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False, index=0):
  259. attn_residual = hidden_states
  260. hidden_states, attn_weights, position_bias = self.attention(
  261. hidden_states,
  262. attention_mask=attention_mask,
  263. position_bias=position_bias,
  264. output_attentions=output_attentions,
  265. index=index,
  266. )
  267. hidden_states = self.dropout(hidden_states)
  268. hidden_states = attn_residual + hidden_states
  269. hidden_states = self.layer_norm(hidden_states)
  270. hidden_states = hidden_states + self.feed_forward(hidden_states)
  271. hidden_states = self.final_layer_norm(hidden_states)
  272. outputs = (hidden_states, position_bias)
  273. if output_attentions:
  274. outputs += (attn_weights,)
  275. return outputs
  276. class WavLMEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
  277. def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
  278. super().__init__()
  279. self.attention = WavLMAttention(
  280. embed_dim=config.hidden_size,
  281. num_heads=config.num_attention_heads,
  282. dropout=config.attention_dropout,
  283. num_buckets=config.num_buckets,
  284. max_distance=config.max_bucket_distance,
  285. has_relative_position_bias=has_relative_position_bias,
  286. )
  287. self.dropout = nn.Dropout(config.hidden_dropout)
  288. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  289. self.feed_forward = WavLMFeedForward(config)
  290. self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  291. def forward(self, hidden_states, attention_mask=None, position_bias=None, output_attentions=False):
  292. attn_residual = hidden_states
  293. hidden_states = self.layer_norm(hidden_states)
  294. hidden_states, attn_weights, position_bias = self.attention(
  295. hidden_states,
  296. attention_mask=attention_mask,
  297. position_bias=position_bias,
  298. output_attentions=output_attentions,
  299. )
  300. hidden_states = self.dropout(hidden_states)
  301. hidden_states = attn_residual + hidden_states
  302. hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
  303. outputs = (hidden_states, position_bias)
  304. if output_attentions:
  305. outputs += (attn_weights,)
  306. return outputs
  307. class WavLMEncoder(nn.Module):
  308. def __init__(self, config):
  309. super().__init__()
  310. self.config = config
  311. self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
  312. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  313. self.dropout = nn.Dropout(config.hidden_dropout)
  314. self.layers = nn.ModuleList(
  315. [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
  316. )
  317. self.gradient_checkpointing = False
  318. def forward(
  319. self,
  320. hidden_states,
  321. attention_mask=None,
  322. output_attentions=False,
  323. output_hidden_states=False,
  324. return_dict=True,
  325. ):
  326. all_hidden_states = () if output_hidden_states else None
  327. all_self_attentions = () if output_attentions else None
  328. if attention_mask is not None:
  329. # make sure padded tokens output 0
  330. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  331. hidden_states[~expand_attention_mask] = 0
  332. position_embeddings = self.pos_conv_embed(hidden_states)
  333. hidden_states = hidden_states + position_embeddings
  334. hidden_states = self.layer_norm(hidden_states)
  335. hidden_states = self.dropout(hidden_states)
  336. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  337. position_bias = None
  338. for i, layer in enumerate(self.layers):
  339. if output_hidden_states:
  340. all_hidden_states = all_hidden_states + (hidden_states,)
  341. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  342. dropout_probability = torch.rand([])
  343. skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
  344. if not skip_the_layer or synced_gpus:
  345. # under fsdp or deepspeed zero3 all gpus must run in sync
  346. layer_outputs = layer(
  347. hidden_states,
  348. attention_mask=attention_mask,
  349. position_bias=position_bias,
  350. output_attentions=output_attentions,
  351. index=i,
  352. )
  353. hidden_states, position_bias = layer_outputs[:2]
  354. if skip_the_layer:
  355. layer_outputs = (None, None, None)
  356. if output_attentions:
  357. all_self_attentions = all_self_attentions + (layer_outputs[2],)
  358. if output_hidden_states:
  359. all_hidden_states = all_hidden_states + (hidden_states,)
  360. if not return_dict:
  361. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  362. return BaseModelOutput(
  363. last_hidden_state=hidden_states,
  364. hidden_states=all_hidden_states,
  365. attentions=all_self_attentions,
  366. )
  367. class WavLMEncoderStableLayerNorm(nn.Module):
  368. def __init__(self, config):
  369. super().__init__()
  370. self.config = config
  371. self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
  372. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  373. self.dropout = nn.Dropout(config.hidden_dropout)
  374. self.layers = nn.ModuleList(
  375. [
  376. WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
  377. for i in range(config.num_hidden_layers)
  378. ]
  379. )
  380. self.gradient_checkpointing = False
  381. def forward(
  382. self,
  383. hidden_states,
  384. attention_mask=None,
  385. output_attentions=False,
  386. output_hidden_states=False,
  387. return_dict=True,
  388. ):
  389. all_hidden_states = () if output_hidden_states else None
  390. all_self_attentions = () if output_attentions else None
  391. if attention_mask is not None:
  392. # make sure padded tokens are not attended to
  393. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  394. hidden_states[~expand_attention_mask] = 0
  395. position_embeddings = self.pos_conv_embed(hidden_states)
  396. hidden_states = hidden_states + position_embeddings
  397. hidden_states = self.dropout(hidden_states)
  398. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  399. position_bias = None
  400. for i, layer in enumerate(self.layers):
  401. if output_hidden_states:
  402. all_hidden_states = all_hidden_states + (hidden_states,)
  403. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  404. dropout_probability = torch.rand([])
  405. skip_the_layer = self.training and i > 0 and (dropout_probability < self.config.layerdrop)
  406. if not skip_the_layer or synced_gpus:
  407. # under fsdp or deepspeed zero3 all gpus must run in sync
  408. # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
  409. layer_outputs = layer(
  410. hidden_states,
  411. attention_mask=attention_mask,
  412. output_attentions=output_attentions,
  413. position_bias=position_bias,
  414. )
  415. hidden_states, position_bias = layer_outputs[:2]
  416. if skip_the_layer:
  417. layer_outputs = (None, None, None)
  418. if output_attentions:
  419. all_self_attentions = all_self_attentions + (layer_outputs[2],)
  420. hidden_states = self.layer_norm(hidden_states)
  421. if output_hidden_states:
  422. all_hidden_states = all_hidden_states + (hidden_states,)
  423. if not return_dict:
  424. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  425. return BaseModelOutput(
  426. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
  427. )
  428. class WavLMGumbelVectorQuantizer(nn.Module):
  429. """
  430. Vector quantization using gumbel softmax. See [CATEGORICAL REPARAMETERIZATION WITH
  431. GUMBEL-SOFTMAX](https://huggingface.co/papers/1611.01144) for more information.
  432. """
  433. def __init__(self, config):
  434. super().__init__()
  435. self.num_groups = config.num_codevector_groups
  436. self.num_vars = config.num_codevectors_per_group
  437. if config.codevector_dim % self.num_groups != 0:
  438. raise ValueError(
  439. f"`config.codevector_dim {config.codevector_dim} must be divisible"
  440. f" by `config.num_codevector_groups` {self.num_groups} "
  441. "for concatenation."
  442. )
  443. # storage for codebook variables (codewords)
  444. self.codevectors = nn.Parameter(
  445. torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups)
  446. )
  447. self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars)
  448. # can be decayed for training
  449. self.temperature = 2
  450. @staticmethod
  451. def _compute_perplexity(probs):
  452. marginal_probs = probs.mean(dim=0)
  453. perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
  454. return perplexity
  455. def forward(self, hidden_states):
  456. batch_size, sequence_length, hidden_size = hidden_states.shape
  457. # project to codevector dim
  458. hidden_states = self.weight_proj(hidden_states)
  459. hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
  460. if self.training:
  461. # sample code vector probs via gumbel in differentiateable way
  462. codevector_probs = nn.functional.gumbel_softmax(hidden_states.float(), tau=self.temperature, hard=True)
  463. codevector_probs = codevector_probs.type_as(hidden_states)
  464. # compute perplexity
  465. codevector_soft_dist = torch.softmax(
  466. hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
  467. )
  468. perplexity = self._compute_perplexity(codevector_soft_dist)
  469. else:
  470. # take argmax in non-differentiable way
  471. # comptute hard codevector distribution (one hot)
  472. codevector_idx = hidden_states.argmax(dim=-1)
  473. codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
  474. -1, codevector_idx.view(-1, 1), 1.0
  475. )
  476. codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
  477. perplexity = self._compute_perplexity(codevector_probs)
  478. codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
  479. # use probs to retrieve codevectors
  480. codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
  481. codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
  482. codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
  483. return codevectors, perplexity
  484. @auto_docstring
  485. class WavLMPreTrainedModel(PreTrainedModel):
  486. config: WavLMConfig
  487. base_model_prefix = "wavlm"
  488. main_input_name = "input_values"
  489. supports_gradient_checkpointing = True
  490. _supports_flash_attn = False
  491. _supports_sdpa = False
  492. _supports_flex_attn = False
  493. def _init_weights(self, module):
  494. """Initialize the weights"""
  495. # gumbel softmax requires special init
  496. if isinstance(module, WavLMGumbelVectorQuantizer):
  497. module.weight_proj.weight.data.normal_(mean=0.0, std=1)
  498. module.weight_proj.bias.data.zero_()
  499. nn.init.uniform_(module.codevectors)
  500. elif isinstance(module, WavLMPositionalConvEmbedding):
  501. nn.init.normal_(
  502. module.conv.weight,
  503. mean=0,
  504. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  505. )
  506. nn.init.constant_(module.conv.bias, 0)
  507. elif isinstance(module, WavLMFeatureProjection):
  508. k = math.sqrt(1 / module.projection.in_features)
  509. nn.init.uniform_(module.projection.weight, a=-k, b=k)
  510. nn.init.uniform_(module.projection.bias, a=-k, b=k)
  511. elif isinstance(module, nn.Linear):
  512. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  513. if module.bias is not None:
  514. module.bias.data.zero_()
  515. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  516. module.bias.data.zero_()
  517. module.weight.data.fill_(1.0)
  518. elif isinstance(module, nn.Conv1d):
  519. nn.init.kaiming_normal_(module.weight)
  520. if module.bias is not None:
  521. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  522. nn.init.uniform_(module.bias, a=-k, b=k)
  523. def _get_feat_extract_output_lengths(
  524. self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
  525. ):
  526. """
  527. Computes the output length of the convolutional layers
  528. """
  529. add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
  530. def _conv_out_length(input_length, kernel_size, stride):
  531. # 1D convolutional layer output length formula taken
  532. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  533. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  534. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  535. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  536. if add_adapter:
  537. for _ in range(self.config.num_adapter_layers):
  538. input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
  539. return input_lengths
  540. def _get_feature_vector_attention_mask(
  541. self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
  542. ):
  543. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  544. # on inference mode.
  545. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  546. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
  547. output_lengths = output_lengths.to(torch.long)
  548. batch_size = attention_mask.shape[0]
  549. attention_mask = torch.zeros(
  550. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  551. )
  552. # these two operations makes sure that all values before the output lengths idxs are attended to
  553. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  554. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  555. return attention_mask
  556. class WavLMNoLayerNormConvLayer(GradientCheckpointingLayer):
  557. def __init__(self, config, layer_id=0):
  558. super().__init__()
  559. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  560. self.out_conv_dim = config.conv_dim[layer_id]
  561. self.conv = nn.Conv1d(
  562. self.in_conv_dim,
  563. self.out_conv_dim,
  564. kernel_size=config.conv_kernel[layer_id],
  565. stride=config.conv_stride[layer_id],
  566. bias=config.conv_bias,
  567. )
  568. self.activation = ACT2FN[config.feat_extract_activation]
  569. def forward(self, hidden_states):
  570. hidden_states = self.conv(hidden_states)
  571. hidden_states = self.activation(hidden_states)
  572. return hidden_states
  573. class WavLMLayerNormConvLayer(GradientCheckpointingLayer):
  574. def __init__(self, config, layer_id=0):
  575. super().__init__()
  576. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  577. self.out_conv_dim = config.conv_dim[layer_id]
  578. self.conv = nn.Conv1d(
  579. self.in_conv_dim,
  580. self.out_conv_dim,
  581. kernel_size=config.conv_kernel[layer_id],
  582. stride=config.conv_stride[layer_id],
  583. bias=config.conv_bias,
  584. )
  585. self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
  586. self.activation = ACT2FN[config.feat_extract_activation]
  587. def forward(self, hidden_states):
  588. hidden_states = self.conv(hidden_states)
  589. hidden_states = hidden_states.transpose(-2, -1)
  590. hidden_states = self.layer_norm(hidden_states)
  591. hidden_states = hidden_states.transpose(-2, -1)
  592. hidden_states = self.activation(hidden_states)
  593. return hidden_states
  594. class WavLMGroupNormConvLayer(GradientCheckpointingLayer):
  595. def __init__(self, config, layer_id=0):
  596. super().__init__()
  597. self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1
  598. self.out_conv_dim = config.conv_dim[layer_id]
  599. self.conv = nn.Conv1d(
  600. self.in_conv_dim,
  601. self.out_conv_dim,
  602. kernel_size=config.conv_kernel[layer_id],
  603. stride=config.conv_stride[layer_id],
  604. bias=config.conv_bias,
  605. )
  606. self.activation = ACT2FN[config.feat_extract_activation]
  607. self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
  608. def forward(self, hidden_states):
  609. hidden_states = self.conv(hidden_states)
  610. hidden_states = self.layer_norm(hidden_states)
  611. hidden_states = self.activation(hidden_states)
  612. return hidden_states
  613. class WavLMFeatureEncoder(nn.Module):
  614. """Construct the features from raw audio waveform"""
  615. def __init__(self, config):
  616. super().__init__()
  617. if config.feat_extract_norm == "group":
  618. conv_layers = [WavLMGroupNormConvLayer(config, layer_id=0)] + [
  619. WavLMNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
  620. ]
  621. elif config.feat_extract_norm == "layer":
  622. conv_layers = [WavLMLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
  623. else:
  624. raise ValueError(
  625. f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
  626. )
  627. self.conv_layers = nn.ModuleList(conv_layers)
  628. self.gradient_checkpointing = False
  629. self._requires_grad = True
  630. def _freeze_parameters(self):
  631. for param in self.parameters():
  632. param.requires_grad = False
  633. self._requires_grad = False
  634. def forward(self, input_values):
  635. hidden_states = input_values[:, None]
  636. # make sure hidden_states require grad for gradient_checkpointing
  637. if self._requires_grad and self.training:
  638. hidden_states.requires_grad = True
  639. for conv_layer in self.conv_layers:
  640. hidden_states = conv_layer(hidden_states)
  641. return hidden_states
  642. class WavLMAdapterLayer(nn.Module):
  643. def __init__(self, config):
  644. super().__init__()
  645. self.conv = nn.Conv1d(
  646. config.output_hidden_size,
  647. 2 * config.output_hidden_size,
  648. config.adapter_kernel_size,
  649. stride=config.adapter_stride,
  650. padding=1,
  651. )
  652. def forward(self, hidden_states):
  653. hidden_states = self.conv(hidden_states)
  654. hidden_states = nn.functional.glu(hidden_states, dim=1)
  655. return hidden_states
  656. class WavLMAdapter(nn.Module):
  657. def __init__(self, config):
  658. super().__init__()
  659. # feature dim might need to be down-projected
  660. if config.output_hidden_size != config.hidden_size:
  661. self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
  662. self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
  663. else:
  664. self.proj = self.proj_layer_norm = None
  665. self.layers = nn.ModuleList(WavLMAdapterLayer(config) for _ in range(config.num_adapter_layers))
  666. self.layerdrop = config.layerdrop
  667. def forward(self, hidden_states):
  668. # down project hidden_states if necessary
  669. if self.proj is not None and self.proj_layer_norm is not None:
  670. hidden_states = self.proj(hidden_states)
  671. hidden_states = self.proj_layer_norm(hidden_states)
  672. hidden_states = hidden_states.transpose(1, 2)
  673. for layer in self.layers:
  674. layerdrop_prob = np.random.random()
  675. if not self.training or (layerdrop_prob > self.layerdrop):
  676. hidden_states = layer(hidden_states)
  677. hidden_states = hidden_states.transpose(1, 2)
  678. return hidden_states
  679. def _compute_mask_indices(
  680. shape: tuple[int, int],
  681. mask_prob: float,
  682. mask_length: int,
  683. attention_mask: Optional[torch.LongTensor] = None,
  684. min_masks: int = 0,
  685. ) -> np.ndarray:
  686. """
  687. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  688. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  689. CPU as part of the preprocessing during training.
  690. Args:
  691. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  692. the first element is the batch size and the second element is the length of the axis to span.
  693. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  694. independently generated mask spans of length `mask_length` is computed by
  695. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  696. actual percentage will be smaller.
  697. mask_length: size of the mask
  698. min_masks: minimum number of masked spans
  699. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  700. each batch dimension.
  701. """
  702. batch_size, sequence_length = shape
  703. if mask_length < 1:
  704. raise ValueError("`mask_length` has to be bigger than 0.")
  705. if mask_length > sequence_length:
  706. raise ValueError(
  707. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  708. f" and `sequence_length`: {sequence_length}`"
  709. )
  710. # epsilon is used for probabilistic rounding
  711. epsilon = np.random.rand(1).item()
  712. def compute_num_masked_span(input_length):
  713. """Given input length, compute how many spans should be masked"""
  714. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  715. num_masked_span = max(num_masked_span, min_masks)
  716. # make sure num masked span <= sequence_length
  717. if num_masked_span * mask_length > sequence_length:
  718. num_masked_span = sequence_length // mask_length
  719. # make sure num_masked span is also <= input_length - (mask_length - 1)
  720. if input_length - (mask_length - 1) < num_masked_span:
  721. num_masked_span = max(input_length - (mask_length - 1), 0)
  722. return num_masked_span
  723. # compute number of masked spans in batch
  724. input_lengths = (
  725. attention_mask.detach().sum(-1).tolist()
  726. if attention_mask is not None
  727. else [sequence_length for _ in range(batch_size)]
  728. )
  729. # SpecAugment mask to fill
  730. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  731. spec_aug_mask_idxs = []
  732. max_num_masked_span = compute_num_masked_span(sequence_length)
  733. if max_num_masked_span == 0:
  734. return spec_aug_mask
  735. for input_length in input_lengths:
  736. # compute num of masked spans for this input
  737. num_masked_span = compute_num_masked_span(input_length)
  738. # get random indices to mask
  739. spec_aug_mask_idx = np.random.choice(
  740. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  741. )
  742. # pick first sampled index that will serve as a dummy index to pad vector
  743. # to ensure same dimension for all batches due to probabilistic rounding
  744. # Picking first sample just pads those vectors twice.
  745. if len(spec_aug_mask_idx) == 0:
  746. # this case can only happen if `input_length` is strictly smaller then
  747. # `sequence_length` in which case the last token has to be a padding
  748. # token which we can use as a dummy mask id
  749. dummy_mask_idx = sequence_length - 1
  750. else:
  751. dummy_mask_idx = spec_aug_mask_idx[0]
  752. spec_aug_mask_idx = np.concatenate(
  753. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  754. )
  755. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  756. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  757. # expand masked indices to masked spans
  758. spec_aug_mask_idxs = np.broadcast_to(
  759. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  760. )
  761. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  762. # add offset to the starting indexes so that indexes now create a span
  763. offsets = np.arange(mask_length)[None, None, :]
  764. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  765. batch_size, max_num_masked_span * mask_length
  766. )
  767. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  768. # ensure that we cannot have indices larger than sequence_length
  769. if spec_aug_mask_idxs.max() > sequence_length - 1:
  770. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  771. # scatter indices to mask
  772. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  773. return spec_aug_mask
# Alias: WavLM uses `Wav2Vec2BaseModelOutput` as its base-model output class under a WavLM-specific name.
WavLMBaseModelOutput = Wav2Vec2BaseModelOutput
  775. @auto_docstring
  776. class WavLMModel(WavLMPreTrainedModel):
  777. def __init__(self, config: WavLMConfig):
  778. super().__init__(config)
  779. self.config = config
  780. self.feature_extractor = WavLMFeatureEncoder(config)
  781. self.feature_projection = WavLMFeatureProjection(config)
  782. # model only needs masking vector if mask prob is > 0.0
  783. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  784. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  785. if config.do_stable_layer_norm:
  786. self.encoder = WavLMEncoderStableLayerNorm(config)
  787. else:
  788. self.encoder = WavLMEncoder(config)
  789. self.adapter = WavLMAdapter(config) if config.add_adapter else None
  790. # Initialize weights and apply final processing
  791. self.post_init()
  792. def freeze_feature_extractor(self):
  793. """
  794. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  795. not be updated during training.
  796. """
  797. warnings.warn(
  798. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  799. "Please use the equivalent `freeze_feature_encoder` method instead.",
  800. FutureWarning,
  801. )
  802. self.freeze_feature_encoder()
  803. def freeze_feature_encoder(self):
  804. """
  805. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  806. not be updated during training.
  807. """
  808. self.feature_extractor._freeze_parameters()
  809. def _mask_hidden_states(
  810. self,
  811. hidden_states: torch.FloatTensor,
  812. mask_time_indices: Optional[torch.FloatTensor] = None,
  813. attention_mask: Optional[torch.LongTensor] = None,
  814. ):
  815. """
  816. Masks extracted features along time axis and/or along feature axis according to
  817. [SpecAugment](https://huggingface.co/papers/1904.08779).
  818. """
  819. # `config.apply_spec_augment` can set masking to False
  820. if not getattr(self.config, "apply_spec_augment", True):
  821. return hidden_states
  822. # generate indices & apply SpecAugment along time axis
  823. batch_size, sequence_length, hidden_size = hidden_states.size()
  824. if mask_time_indices is not None:
  825. # apply SpecAugment along time axis with given mask_time_indices
  826. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  827. elif self.config.mask_time_prob > 0 and self.training:
  828. mask_time_indices = _compute_mask_indices(
  829. (batch_size, sequence_length),
  830. mask_prob=self.config.mask_time_prob,
  831. mask_length=self.config.mask_time_length,
  832. attention_mask=attention_mask,
  833. min_masks=self.config.mask_time_min_masks,
  834. )
  835. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  836. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  837. if self.config.mask_feature_prob > 0 and self.training:
  838. # generate indices & apply SpecAugment along feature axis
  839. mask_feature_indices = _compute_mask_indices(
  840. (batch_size, hidden_size),
  841. mask_prob=self.config.mask_feature_prob,
  842. mask_length=self.config.mask_feature_length,
  843. min_masks=self.config.mask_feature_min_masks,
  844. )
  845. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  846. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  847. hidden_states[mask_feature_indices] = 0
  848. return hidden_states
  849. @auto_docstring
  850. def forward(
  851. self,
  852. input_values: Optional[torch.Tensor],
  853. attention_mask: Optional[torch.Tensor] = None,
  854. mask_time_indices: Optional[torch.FloatTensor] = None,
  855. output_attentions: Optional[bool] = None,
  856. output_hidden_states: Optional[bool] = None,
  857. return_dict: Optional[bool] = None,
  858. ) -> Union[tuple, WavLMBaseModelOutput]:
  859. r"""
  860. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  861. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  862. masked extracted features in *config.proj_codevector_dim* space.
  863. """
  864. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  865. output_hidden_states = (
  866. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  867. )
  868. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  869. extract_features = self.feature_extractor(input_values)
  870. extract_features = extract_features.transpose(1, 2)
  871. if attention_mask is not None:
  872. # compute reduced attention_mask corresponding to feature vectors
  873. attention_mask = self._get_feature_vector_attention_mask(
  874. extract_features.shape[1], attention_mask, add_adapter=False
  875. )
  876. hidden_states, extract_features = self.feature_projection(extract_features)
  877. hidden_states = self._mask_hidden_states(
  878. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  879. )
  880. encoder_outputs = self.encoder(
  881. hidden_states,
  882. attention_mask=attention_mask,
  883. output_attentions=output_attentions,
  884. output_hidden_states=output_hidden_states,
  885. return_dict=return_dict,
  886. )
  887. hidden_states = encoder_outputs[0]
  888. if self.adapter is not None:
  889. hidden_states = self.adapter(hidden_states)
  890. if not return_dict:
  891. return (hidden_states, extract_features) + encoder_outputs[1:]
  892. return WavLMBaseModelOutput(
  893. last_hidden_state=hidden_states,
  894. extract_features=extract_features,
  895. hidden_states=encoder_outputs.hidden_states,
  896. attentions=encoder_outputs.attentions,
  897. )
# Index in the tuple returned by `WavLMModel.forward` (non-`return_dict` case) at which the remaining
# encoder outputs begin; positions 0 and 1 hold the last hidden state and the extracted features.
_HIDDEN_STATES_START_POSITION = 2
  899. @auto_docstring(
  900. custom_intro="""
  901. WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
  902. """
  903. )
  904. class WavLMForCTC(WavLMPreTrainedModel):
  905. def __init__(self, config, target_lang: Optional[str] = None):
  906. r"""
  907. target_lang (`str`, *optional*):
  908. Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
  909. adapter.<lang>.bin. Only relevant when using an instance of [`WavLMForCTC`] with adapters. Uses 'eng' by
  910. default.
  911. """
  912. super().__init__(config)
  913. self.wavlm = WavLMModel(config)
  914. self.dropout = nn.Dropout(config.final_dropout)
  915. self.target_lang = target_lang
  916. if config.vocab_size is None:
  917. raise ValueError(
  918. f"You are trying to instantiate {self.__class__} with a configuration that "
  919. "does not define the vocabulary size of the language model head. Please "
  920. "instantiate the model as follows: `WavLMForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  921. "or define `vocab_size` of your model's configuration."
  922. )
  923. output_hidden_size = (
  924. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  925. )
  926. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  927. # Initialize weights and apply final processing
  928. self.post_init()
  929. def tie_weights(self):
  930. """
  931. This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
  932. passing `target_lang=...` to `from_pretrained(...)`.
  933. This method is **not** supposed to be called by the user and is prone to be changed in the future.
  934. """
  935. # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
  936. # correctly load adapter layers for WavLM so that we do not have to introduce a new API to
  937. # [`PreTrainedModel`]. While slightly hacky, WavLM never has to tie input and output embeddings, so that it is
  938. # ok to repurpose this function here.
  939. target_lang = self.target_lang
  940. if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
  941. raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
  942. elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
  943. logger.info("By default `target_lang` is set to 'eng'.")
  944. elif target_lang is not None:
  945. self.load_adapter(target_lang, force_load=True)
  946. def freeze_feature_extractor(self):
  947. """
  948. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  949. not be updated during training.
  950. """
  951. warnings.warn(
  952. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  953. "Please use the equivalent `freeze_feature_encoder` method instead.",
  954. FutureWarning,
  955. )
  956. self.freeze_feature_encoder()
  957. def freeze_feature_encoder(self):
  958. """
  959. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  960. not be updated during training.
  961. """
  962. self.wavlm.feature_extractor._freeze_parameters()
  963. def freeze_base_model(self):
  964. """
  965. Calling this function will disable the gradient computation for the base model so that its parameters will not
  966. be updated during training. Only the classification head will be updated.
  967. """
  968. for param in self.wavlm.parameters():
  969. param.requires_grad = False
  970. @auto_docstring
  971. def forward(
  972. self,
  973. input_values: Optional[torch.Tensor],
  974. attention_mask: Optional[torch.Tensor] = None,
  975. output_attentions: Optional[bool] = None,
  976. output_hidden_states: Optional[bool] = None,
  977. return_dict: Optional[bool] = None,
  978. labels: Optional[torch.Tensor] = None,
  979. ) -> Union[tuple, CausalLMOutput]:
  980. r"""
  981. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  982. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  983. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  984. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  985. config.vocab_size - 1]`.
  986. """
  987. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  988. if labels is not None and labels.max() >= self.config.vocab_size:
  989. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  990. outputs = self.wavlm(
  991. input_values,
  992. attention_mask=attention_mask,
  993. output_attentions=output_attentions,
  994. output_hidden_states=output_hidden_states,
  995. return_dict=return_dict,
  996. )
  997. hidden_states = outputs[0]
  998. hidden_states = self.dropout(hidden_states)
  999. logits = self.lm_head(hidden_states)
  1000. loss = None
  1001. if labels is not None:
  1002. # retrieve loss input_lengths from attention_mask
  1003. attention_mask = (
  1004. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  1005. )
  1006. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  1007. # assuming that padded tokens are filled with -100
  1008. # when not being attended to
  1009. labels_mask = labels >= 0
  1010. target_lengths = labels_mask.sum(-1)
  1011. flattened_targets = labels.masked_select(labels_mask)
  1012. # ctc_loss doesn't support fp16
  1013. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  1014. with torch.backends.cudnn.flags(enabled=False):
  1015. loss = nn.functional.ctc_loss(
  1016. log_probs,
  1017. flattened_targets,
  1018. input_lengths,
  1019. target_lengths,
  1020. blank=self.config.pad_token_id,
  1021. reduction=self.config.ctc_loss_reduction,
  1022. zero_infinity=self.config.ctc_zero_infinity,
  1023. )
  1024. if not return_dict:
  1025. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1026. return ((loss,) + output) if loss is not None else output
  1027. return CausalLMOutput(
  1028. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  1029. )
  1030. @auto_docstring(
  1031. custom_intro="""
  1032. WavLM Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
  1033. SUPERB Keyword Spotting.
  1034. """
  1035. )
  1036. class WavLMForSequenceClassification(WavLMPreTrainedModel):
  1037. def __init__(self, config):
  1038. super().__init__(config)
  1039. if hasattr(config, "add_adapter") and config.add_adapter:
  1040. raise ValueError(
  1041. "Sequence classification does not support the use of WavLM adapters (config.add_adapter=True)"
  1042. )
  1043. self.wavlm = WavLMModel(config)
  1044. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1045. if config.use_weighted_layer_sum:
  1046. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1047. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  1048. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  1049. # Initialize weights and apply final processing
  1050. self.post_init()
  1051. def freeze_feature_extractor(self):
  1052. """
  1053. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  1054. not be updated during training.
  1055. """
  1056. warnings.warn(
  1057. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1058. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1059. FutureWarning,
  1060. )
  1061. self.freeze_feature_encoder()
  1062. def freeze_feature_encoder(self):
  1063. """
  1064. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1065. not be updated during training.
  1066. """
  1067. self.wavlm.feature_extractor._freeze_parameters()
  1068. def freeze_base_model(self):
  1069. """
  1070. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1071. be updated during training. Only the classification head will be updated.
  1072. """
  1073. for param in self.wavlm.parameters():
  1074. param.requires_grad = False
  1075. @auto_docstring
  1076. def forward(
  1077. self,
  1078. input_values: Optional[torch.Tensor],
  1079. attention_mask: Optional[torch.Tensor] = None,
  1080. output_attentions: Optional[bool] = None,
  1081. output_hidden_states: Optional[bool] = None,
  1082. return_dict: Optional[bool] = None,
  1083. labels: Optional[torch.Tensor] = None,
  1084. ) -> Union[tuple, SequenceClassifierOutput]:
  1085. r"""
  1086. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1087. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1088. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1089. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1090. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1091. into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
  1092. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1093. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1094. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1095. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1096. """
  1097. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1098. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1099. outputs = self.wavlm(
  1100. input_values,
  1101. attention_mask=attention_mask,
  1102. output_attentions=output_attentions,
  1103. output_hidden_states=output_hidden_states,
  1104. return_dict=return_dict,
  1105. )
  1106. if self.config.use_weighted_layer_sum:
  1107. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1108. hidden_states = torch.stack(hidden_states, dim=1)
  1109. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1110. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1111. else:
  1112. hidden_states = outputs[0]
  1113. hidden_states = self.projector(hidden_states)
  1114. if attention_mask is None:
  1115. pooled_output = hidden_states.mean(dim=1)
  1116. else:
  1117. padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  1118. expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  1119. hidden_states[~expand_padding_mask] = 0.0
  1120. pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
  1121. logits = self.classifier(pooled_output)
  1122. loss = None
  1123. if labels is not None:
  1124. loss_fct = CrossEntropyLoss()
  1125. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1126. if not return_dict:
  1127. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1128. return ((loss,) + output) if loss is not None else output
  1129. return SequenceClassifierOutput(
  1130. loss=loss,
  1131. logits=logits,
  1132. hidden_states=outputs.hidden_states,
  1133. attentions=outputs.attentions,
  1134. )
  1135. @auto_docstring
  1136. class WavLMForAudioFrameClassification(WavLMPreTrainedModel):
  1137. def __init__(self, config):
  1138. super().__init__(config)
  1139. if hasattr(config, "add_adapter") and config.add_adapter:
  1140. raise ValueError(
  1141. "Audio frame classification does not support the use of WavLM adapters (config.add_adapter=True)"
  1142. )
  1143. self.wavlm = WavLMModel(config)
  1144. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1145. if config.use_weighted_layer_sum:
  1146. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1147. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1148. self.num_labels = config.num_labels
  1149. self.init_weights()
  1150. def freeze_feature_extractor(self):
  1151. """
  1152. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1153. not be updated during training.
  1154. """
  1155. warnings.warn(
  1156. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1157. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1158. FutureWarning,
  1159. )
  1160. self.freeze_feature_encoder()
  1161. def freeze_feature_encoder(self):
  1162. """
  1163. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1164. not be updated during training.
  1165. """
  1166. self.wavlm.feature_extractor._freeze_parameters()
  1167. def freeze_base_model(self):
  1168. """
  1169. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1170. be updated during training. Only the classification head will be updated.
  1171. """
  1172. for param in self.wavlm.parameters():
  1173. param.requires_grad = False
  1174. @auto_docstring
  1175. def forward(
  1176. self,
  1177. input_values: Optional[torch.Tensor],
  1178. attention_mask: Optional[torch.Tensor] = None,
  1179. labels: Optional[torch.Tensor] = None,
  1180. output_attentions: Optional[bool] = None,
  1181. output_hidden_states: Optional[bool] = None,
  1182. return_dict: Optional[bool] = None,
  1183. ) -> Union[tuple, TokenClassifierOutput]:
  1184. r"""
  1185. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1186. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1187. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1188. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1189. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1190. into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
  1191. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1192. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1193. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1194. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1195. """
  1196. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1197. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1198. outputs = self.wavlm(
  1199. input_values,
  1200. attention_mask=attention_mask,
  1201. output_attentions=output_attentions,
  1202. output_hidden_states=output_hidden_states,
  1203. return_dict=return_dict,
  1204. )
  1205. if self.config.use_weighted_layer_sum:
  1206. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1207. hidden_states = torch.stack(hidden_states, dim=1)
  1208. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1209. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1210. else:
  1211. hidden_states = outputs[0]
  1212. logits = self.classifier(hidden_states)
  1213. loss = None
  1214. if labels is not None:
  1215. loss_fct = CrossEntropyLoss()
  1216. loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1))
  1217. if not return_dict:
  1218. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1219. return output
  1220. return TokenClassifierOutput(
  1221. loss=loss,
  1222. logits=logits,
  1223. hidden_states=outputs.hidden_states,
  1224. attentions=outputs.attentions,
  1225. )
  1226. class AMSoftmaxLoss(nn.Module):
  1227. def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4):
  1228. super().__init__()
  1229. self.scale = scale
  1230. self.margin = margin
  1231. self.num_labels = num_labels
  1232. self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True)
  1233. self.loss = nn.CrossEntropyLoss()
  1234. def forward(self, hidden_states, labels):
  1235. labels = labels.flatten()
  1236. weight = nn.functional.normalize(self.weight, dim=0)
  1237. hidden_states = nn.functional.normalize(hidden_states, dim=1)
  1238. cos_theta = torch.mm(hidden_states, weight)
  1239. psi = cos_theta - self.margin
  1240. onehot = nn.functional.one_hot(labels, self.num_labels)
  1241. logits = self.scale * torch.where(onehot.bool(), psi, cos_theta)
  1242. loss = self.loss(logits, labels)
  1243. return loss
  1244. class TDNNLayer(nn.Module):
  1245. def __init__(self, config, layer_id=0):
  1246. super().__init__()
  1247. self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id]
  1248. self.out_conv_dim = config.tdnn_dim[layer_id]
  1249. self.kernel_size = config.tdnn_kernel[layer_id]
  1250. self.dilation = config.tdnn_dilation[layer_id]
  1251. self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim)
  1252. self.activation = nn.ReLU()
  1253. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1254. if is_peft_available():
  1255. from peft.tuners.lora import LoraLayer
  1256. if is_peft_available():
  1257. if isinstance(self.kernel, LoraLayer):
  1258. warnings.warn(
  1259. "Detected LoRA on TDNNLayer. LoRA weights won't be applied due to optimization. "
  1260. "You should exclude TDNNLayer from LoRA's target modules.",
  1261. )
  1262. # for backward compatibility, we keep nn.Linear but call F.conv1d for speed up
  1263. hidden_states = hidden_states.transpose(1, 2)
  1264. weight = self.kernel.weight.view(self.out_conv_dim, self.kernel_size, self.in_conv_dim).transpose(1, 2)
  1265. hidden_states = nn.functional.conv1d(hidden_states, weight, self.kernel.bias, dilation=self.dilation)
  1266. hidden_states = hidden_states.transpose(1, 2)
  1267. hidden_states = self.activation(hidden_states)
  1268. return hidden_states
  1269. @auto_docstring(
  1270. custom_intro="""
  1271. WavLM Model with an XVector feature extraction head on top for tasks like Speaker Verification.
  1272. """
  1273. )
  1274. class WavLMForXVector(WavLMPreTrainedModel):
  1275. def __init__(self, config):
  1276. super().__init__(config)
  1277. self.wavlm = WavLMModel(config)
  1278. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1279. if config.use_weighted_layer_sum:
  1280. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1281. self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0])
  1282. tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))]
  1283. self.tdnn = nn.ModuleList(tdnn_layers)
  1284. self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim)
  1285. self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim)
  1286. self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels)
  1287. self.init_weights()
  1288. def freeze_feature_extractor(self):
  1289. """
  1290. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1291. not be updated during training.
  1292. """
  1293. warnings.warn(
  1294. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  1295. "Please use the equivalent `freeze_feature_encoder` method instead.",
  1296. FutureWarning,
  1297. )
  1298. self.freeze_feature_encoder()
  1299. def freeze_feature_encoder(self):
  1300. """
  1301. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  1302. not be updated during training.
  1303. """
  1304. self.wavlm.feature_extractor._freeze_parameters()
  1305. def freeze_base_model(self):
  1306. """
  1307. Calling this function will disable the gradient computation for the base model so that its parameters will not
  1308. be updated during training. Only the classification head will be updated.
  1309. """
  1310. for param in self.wavlm.parameters():
  1311. param.requires_grad = False
  1312. def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
  1313. """
  1314. Computes the output length of the TDNN layers
  1315. """
  1316. def _conv_out_length(input_length, kernel_size, stride):
  1317. # 1D convolutional layer output length formula taken
  1318. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  1319. return (input_length - kernel_size) // stride + 1
  1320. for kernel_size in self.config.tdnn_kernel:
  1321. input_lengths = _conv_out_length(input_lengths, kernel_size, 1)
  1322. return input_lengths
  1323. @auto_docstring
  1324. def forward(
  1325. self,
  1326. input_values: Optional[torch.Tensor],
  1327. attention_mask: Optional[torch.Tensor] = None,
  1328. output_attentions: Optional[bool] = None,
  1329. output_hidden_states: Optional[bool] = None,
  1330. return_dict: Optional[bool] = None,
  1331. labels: Optional[torch.Tensor] = None,
  1332. ) -> Union[tuple, XVectorOutput]:
  1333. r"""
  1334. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1335. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1336. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1337. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1338. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1339. into a tensor of type `torch.FloatTensor`. See [`WavLMProcessor.__call__`] for details.
  1340. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1341. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1342. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1343. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1344. """
  1345. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1346. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1347. outputs = self.wavlm(
  1348. input_values,
  1349. attention_mask=attention_mask,
  1350. output_attentions=output_attentions,
  1351. output_hidden_states=output_hidden_states,
  1352. return_dict=return_dict,
  1353. )
  1354. if self.config.use_weighted_layer_sum:
  1355. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1356. hidden_states = torch.stack(hidden_states, dim=1)
  1357. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1358. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1359. else:
  1360. hidden_states = outputs[0]
  1361. hidden_states = self.projector(hidden_states)
  1362. for tdnn_layer in self.tdnn:
  1363. hidden_states = tdnn_layer(hidden_states)
  1364. # Statistic Pooling
  1365. if attention_mask is None:
  1366. mean_features = hidden_states.mean(dim=1)
  1367. std_features = hidden_states.std(dim=1)
  1368. else:
  1369. feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1))
  1370. tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths)
  1371. mean_features = []
  1372. std_features = []
  1373. for i, length in enumerate(tdnn_output_lengths):
  1374. mean_features.append(hidden_states[i, :length].mean(dim=0))
  1375. std_features.append(hidden_states[i, :length].std(dim=0))
  1376. mean_features = torch.stack(mean_features)
  1377. std_features = torch.stack(std_features)
  1378. statistic_pooling = torch.cat([mean_features, std_features], dim=-1)
  1379. output_embeddings = self.feature_extractor(statistic_pooling)
  1380. logits = self.classifier(output_embeddings)
  1381. loss = None
  1382. if labels is not None:
  1383. loss = self.objective(logits, labels)
  1384. if not return_dict:
  1385. output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:]
  1386. return ((loss,) + output) if loss is not None else output
  1387. return XVectorOutput(
  1388. loss=loss,
  1389. logits=logits,
  1390. embeddings=output_embeddings,
  1391. hidden_states=outputs.hidden_states,
  1392. attentions=outputs.attentions,
  1393. )
# Names exported by `from ... import *`; these are the public classes of this module.
__all__ = [
    "WavLMForAudioFrameClassification",
    "WavLMForCTC",
    "WavLMForSequenceClassification",
    "WavLMForXVector",
    "WavLMModel",
    "WavLMPreTrainedModel",
]