modeling_hubert.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/hubert/modular_hubert.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_hubert.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import warnings
  22. from typing import Callable, Optional, Union
  23. import numpy as np
  24. import torch
  25. import torch.nn as nn
  26. from torch.nn import CrossEntropyLoss
  27. from ...activations import ACT2FN
  28. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  29. from ...integrations.fsdp import is_fsdp_managed_module
  30. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
  34. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  35. from ...processing_utils import Unpack
  36. from ...utils import auto_docstring, is_torch_flex_attn_available, logging
  37. from .configuration_hubert import HubertConfig
  38. if is_torch_flex_attn_available():
  39. from ...integrations.flex_attention import make_flex_block_causal_mask
  40. logger = logging.get_logger(__name__)
  class HubertPositionalConvEmbedding(nn.Module):
      """Convolutional positional embedding computed from the hidden states themselves.

      A grouped ``nn.Conv1d`` (``hidden_size`` -> ``hidden_size``, kernel ``config.num_conv_pos_embeddings``,
      padding ``kernel // 2``) runs over the sequence axis. If ``config.conv_pos_batch_norm`` is set, the
      input is batch-normalised before the convolution. Otherwise the convolution weight is wrapped with
      weight normalisation along ``dim=2``. The result is trimmed by ``HubertSamePadLayer`` and passed
      through ``config.feat_extract_activation``.
      """

      def __init__(self, config):
          super().__init__()
          kernel_size = config.num_conv_pos_embeddings
          self.conv = nn.Conv1d(
              config.hidden_size,
              config.hidden_size,
              kernel_size=kernel_size,
              padding=kernel_size // 2,
              groups=config.num_conv_pos_embedding_groups,
          )

          self.batch_norm = None
          if config.conv_pos_batch_norm:
              self.batch_norm = nn.BatchNorm1d(config.hidden_size)
          else:
              # Prefer the parametrization-based weight norm when this torch build provides it;
              # otherwise fall back to the legacy implementation.
              parametrizations = nn.utils.parametrizations
              if hasattr(parametrizations, "weight_norm"):
                  weight_norm = parametrizations.weight_norm
              else:
                  weight_norm = nn.utils.weight_norm

              if not is_deepspeed_zero3_enabled():
                  self.conv = weight_norm(self.conv, name="weight", dim=2)
              else:
                  import deepspeed

                  # Under ZeRO-3 the weight is partitioned, so gather it before re-parametrising it.
                  with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
                      self.conv = weight_norm(self.conv, name="weight", dim=2)
                  # The g/v tensors live in different places depending on which weight-norm flavour was used.
                  if hasattr(self.conv, "parametrizations"):
                      weight_param = self.conv.parametrizations.weight
                      weight_g, weight_v = weight_param.original0, weight_param.original1
                  else:
                      weight_g, weight_v = self.conv.weight_g, self.conv.weight_v
                  deepspeed.zero.register_external_parameter(self, weight_v)
                  deepspeed.zero.register_external_parameter(self, weight_g)

          self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
          self.activation = ACT2FN[config.feat_extract_activation]

      def forward(self, hidden_states):
          # Conv1d wants channels on axis 1; the input carries hidden_size on its last axis.
          channels_first = hidden_states.transpose(1, 2)
          if self.batch_norm is not None:
              channels_first = self.batch_norm(channels_first)
          embedded = self.activation(self.padding(self.conv(channels_first)))
          # Restore the caller's layout.
          return embedded.transpose(1, 2)
  class HubertSamePadLayer(nn.Module):
      """Drops the trailing step left over by an even-sized, ``kernel // 2``-padded convolution.

      With kernel size ``k`` and padding ``k // 2`` (stride 1), a Conv1d yields one extra step on the last
      axis when ``k`` is even. This layer removes that step and is a no-op for odd ``k``.
      """

      def __init__(self, num_conv_pos_embeddings):
          super().__init__()
          # 1 for even kernel sizes, 0 for odd ones.
          self.num_pad_remove = int(num_conv_pos_embeddings % 2 == 0)

      def forward(self, hidden_states):
          if self.num_pad_remove <= 0:
              return hidden_states
          return hidden_states[:, :, : -self.num_pad_remove]
  class HubertNoLayerNormConvLayer(GradientCheckpointingLayer):
      """Feature-extractor stage ``layer_id``: a strided ``Conv1d`` followed by the configured activation, with no normalisation."""

      def __init__(self, config, layer_id=0):
          super().__init__()
          # Stage 0 reads a single input channel; later stages consume the previous stage's channels.
          if layer_id > 0:
              self.in_conv_dim = config.conv_dim[layer_id - 1]
          else:
              self.in_conv_dim = 1
          self.out_conv_dim = config.conv_dim[layer_id]

          self.conv = nn.Conv1d(
              self.in_conv_dim,
              self.out_conv_dim,
              kernel_size=config.conv_kernel[layer_id],
              stride=config.conv_stride[layer_id],
              bias=config.conv_bias,
          )
          self.activation = ACT2FN[config.feat_extract_activation]

      def forward(self, hidden_states):
          return self.activation(self.conv(hidden_states))
  class HubertLayerNormConvLayer(GradientCheckpointingLayer):
      """Feature-extractor stage ``layer_id``: ``Conv1d`` -> LayerNorm over channels -> activation."""

      def __init__(self, config, layer_id=0):
          super().__init__()
          # Stage 0 reads a single input channel; later stages consume the previous stage's channels.
          if layer_id > 0:
              self.in_conv_dim = config.conv_dim[layer_id - 1]
          else:
              self.in_conv_dim = 1
          self.out_conv_dim = config.conv_dim[layer_id]

          self.conv = nn.Conv1d(
              self.in_conv_dim,
              self.out_conv_dim,
              kernel_size=config.conv_kernel[layer_id],
              stride=config.conv_stride[layer_id],
              bias=config.conv_bias,
          )
          self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
          self.activation = ACT2FN[config.feat_extract_activation]

      def forward(self, hidden_states):
          conv_out = self.conv(hidden_states)
          # nn.LayerNorm normalises the last axis, so swap channels onto it and back again.
          normed = self.layer_norm(conv_out.transpose(-2, -1)).transpose(-2, -1)
          return self.activation(normed)
  class HubertGroupNormConvLayer(GradientCheckpointingLayer):
      """Feature-extractor stage ``layer_id``: ``Conv1d`` -> GroupNorm -> activation.

      The GroupNorm uses one group per channel (``num_groups == num_channels == out_conv_dim``), so each
      channel is normalised independently.
      """

      def __init__(self, config, layer_id=0):
          super().__init__()
          # Stage 0 reads a single input channel; later stages consume the previous stage's channels.
          if layer_id > 0:
              self.in_conv_dim = config.conv_dim[layer_id - 1]
          else:
              self.in_conv_dim = 1
          self.out_conv_dim = config.conv_dim[layer_id]

          self.conv = nn.Conv1d(
              self.in_conv_dim,
              self.out_conv_dim,
              kernel_size=config.conv_kernel[layer_id],
              stride=config.conv_stride[layer_id],
              bias=config.conv_bias,
          )
          self.activation = ACT2FN[config.feat_extract_activation]
          self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)

      def forward(self, hidden_states):
          return self.activation(self.layer_norm(self.conv(hidden_states)))
  class HubertFeatureEncoder(nn.Module):
      """Construct the features from raw audio waveform.

      ``config.feat_extract_norm`` selects the stack layout:

      * ``"group"``: a GroupNorm stage first, then ``num_feat_extract_layers - 1`` stages without normalisation.
      * ``"layer"``: ``num_feat_extract_layers`` LayerNorm stages.

      Raises:
          ValueError: if ``config.feat_extract_norm`` is neither ``"group"`` nor ``"layer"``.
      """

      def __init__(self, config):
          super().__init__()
          if config.feat_extract_norm == "group":
              conv_layers = [HubertGroupNormConvLayer(config, layer_id=0)] + [
                  HubertNoLayerNormConvLayer(config, layer_id=i + 1) for i in range(config.num_feat_extract_layers - 1)
              ]
          elif config.feat_extract_norm == "layer":
              conv_layers = [HubertLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers)]
          else:
              raise ValueError(
                  f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']"
              )
          self.conv_layers = nn.ModuleList(conv_layers)
          self.gradient_checkpointing = False
          # Cleared by `_freeze_parameters`; gates the requires_grad forcing in `forward`.
          self._requires_grad = True

      def _freeze_parameters(self):
          """Disable gradients for every parameter of the feature encoder."""
          for param in self.parameters():
              param.requires_grad = False
          self._requires_grad = False

      def forward(self, input_values):
          """Run the conv stack on ``input_values`` after inserting a channel axis at position 1."""
          hidden_states = input_values[:, None]
          # Make sure hidden_states require grad for gradient checkpointing. `requires_grad` can only be
          # assigned on leaf tensors. If `input_values` already requires grad, the slice above is a non-leaf
          # that already requires grad, so assigning would raise instead of being a no-op.
          if self._requires_grad and self.training and not hidden_states.requires_grad:
              hidden_states.requires_grad = True
          for conv_layer in self.conv_layers:
              hidden_states = conv_layer(hidden_states)
          return hidden_states
  class HubertFeatureProjection(nn.Module):
      """Project feature-extractor outputs (``config.conv_dim[-1]`` wide) to ``config.hidden_size``.

      When ``config.feat_proj_layer_norm`` is truthy the input is layer-normalised first. Dropout with
      probability ``config.feat_proj_dropout`` is applied to the projected output.
      """

      def __init__(self, config):
          super().__init__()
          extractor_dim = config.conv_dim[-1]
          self.feat_proj_layer_norm = config.feat_proj_layer_norm
          if self.feat_proj_layer_norm:
              self.layer_norm = nn.LayerNorm(extractor_dim, eps=config.layer_norm_eps)
          self.projection = nn.Linear(extractor_dim, config.hidden_size)
          self.dropout = nn.Dropout(config.feat_proj_dropout)

      def forward(self, hidden_states):
          # Only the projected states are returned.
          normalized = self.layer_norm(hidden_states) if self.feat_proj_layer_norm else hidden_states
          return self.dropout(self.projection(normalized))
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: Optional[torch.Tensor],
      scaling: Optional[float] = None,
      dropout: float = 0.0,
      head_mask: Optional[torch.Tensor] = None,
      **kwargs,
  ):
      """Plain PyTorch scaled dot-product attention.

      Args:
          module: owning module; only its ``training`` flag is read (to gate dropout).
          query, key, value: 4D tensors with heads on axis 1, presumably
              ``(batch, num_heads, seq_len, head_dim)`` as produced by ``HubertAttention``.
          attention_mask: optional additive mask, broadcast-added to the raw scores before softmax.
          scaling: score multiplier; defaults to ``query.size(-1) ** -0.5``.
          dropout: dropout probability on the attention probabilities.
          head_mask: optional per-head multiplier, reshaped to ``(1, num_heads, 1, 1)``.
          **kwargs: ignored; accepted so all attention backends share a signature.

      Returns:
          ``(attn_output, attn_weights)``, where ``attn_output`` has axes 1 and 2 swapped (heads moved
          after the query axis) and is contiguous, and ``attn_weights`` are the post-softmax,
          post-head-mask, post-dropout probabilities.
      """
      scale = query.size(-1) ** -0.5 if scaling is None else scaling

      scores = (query @ key.transpose(2, 3)) * scale
      if attention_mask is not None:
          scores = scores + attention_mask

      probs = nn.functional.softmax(scores, dim=-1)
      if head_mask is not None:
          probs = probs * head_mask.view(1, -1, 1, 1)
      probs = nn.functional.dropout(probs, p=dropout, training=module.training)

      context = (probs @ value).transpose(1, 2).contiguous()
      return context, probs
  class HubertAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper.

      The kernel is picked from ``config._attn_implementation``: ``"eager"`` uses ``eager_attention_forward``,
      and any other value is looked up in ``ALL_ATTENTION_FUNCTIONS``. Without a ``config`` the eager kernel
      is used.
      """

      def __init__(
          self,
          embed_dim: int,
          num_heads: int,
          dropout: float = 0.0,
          is_decoder: bool = False,
          bias: bool = True,
          is_causal: bool = False,
          config: Optional[HubertConfig] = None,
      ):
          super().__init__()
          self.embed_dim = embed_dim
          self.num_heads = num_heads
          self.dropout = dropout
          self.head_dim = embed_dim // num_heads
          self.config = config

          if (self.head_dim * num_heads) != self.embed_dim:
              raise ValueError(
                  f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                  f" and `num_heads`: {num_heads})."
              )
          self.scaling = self.head_dim**-0.5
          self.is_decoder = is_decoder
          self.is_causal = is_causal

          self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
          self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
          self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
          self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

      def forward(
          self,
          hidden_states: torch.Tensor,
          key_value_states: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          layer_head_mask: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = False,
          # TODO: we need a refactor so that the different attention modules can get their specific kwargs
          # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
          """Input shape: Batch x Time x Channel.

          Returns ``(attn_output, attn_weights, None)``. The third element is always ``None``.
          """
          # If key_value_states are provided, this layer is used as a cross-attention layer.
          is_cross_attention = key_value_states is not None

          # determine input shapes
          bsz, tgt_len = hidden_states.shape[:-1]
          src_len = key_value_states.shape[1] if is_cross_attention else tgt_len

          q_input_shape = (bsz, tgt_len, -1, self.head_dim)
          kv_input_shape = (bsz, src_len, -1, self.head_dim)

          # Project, then split heads: (bsz, len, embed_dim) -> (bsz, num_heads, len, head_dim).
          query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)

          current_states = key_value_states if is_cross_attention else hidden_states
          key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2)
          value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2)

          # `config` is optional in the constructor. Without it there is no `_attn_implementation`
          # to consult, so use the eager kernel instead of failing with an AttributeError.
          attn_implementation = "eager" if self.config is None else self.config._attn_implementation
          attention_interface: Callable = eager_attention_forward
          if attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[attn_implementation]

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.dropout,
              scaling=self.scaling,
              output_attentions=output_attentions,
              head_mask=layer_head_mask,
              **kwargs,
          )

          # Merge heads back into (bsz, tgt_len, embed_dim) before the output projection.
          attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
          attn_output = self.out_proj(attn_output)

          return attn_output, attn_weights, None
  class HubertFeedForward(nn.Module):
      """Position-wise MLP: dense -> activation -> dropout -> dense -> dropout.

      ``config.hidden_act`` may be an activation name (looked up in ``ACT2FN``) or a callable used as-is.
      """

      def __init__(self, config):
          super().__init__()
          self.intermediate_dropout = nn.Dropout(config.activation_dropout)
          self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
          activation = config.hidden_act
          self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
          self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
          self.output_dropout = nn.Dropout(config.hidden_dropout)

      def forward(self, hidden_states):
          expanded = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
          expanded = self.intermediate_dropout(expanded)
          return self.output_dropout(self.output_dense(expanded))
  class HubertEncoderLayer(GradientCheckpointingLayer):
      """Post-norm Transformer block.

      Computes ``LN(x + drop(attn(x)))`` followed by ``LN(h + ffn(h))``.
      """

      def __init__(self, config):
          super().__init__()
          self.attention = HubertAttention(
              embed_dim=config.hidden_size,
              num_heads=config.num_attention_heads,
              dropout=config.attention_dropout,
              is_decoder=False,
              config=config,
          )
          self.dropout = nn.Dropout(config.hidden_dropout)
          self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.feed_forward = HubertFeedForward(config)
          self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

      def forward(self, hidden_states, attention_mask=None, output_attentions=False):
          """Return ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when ``output_attentions``."""
          residual = hidden_states
          attended, attn_weights, _ = self.attention(
              hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
          )
          hidden_states = self.layer_norm(residual + self.dropout(attended))
          hidden_states = self.final_layer_norm(hidden_states + self.feed_forward(hidden_states))
          return (hidden_states, attn_weights) if output_attentions else (hidden_states,)
  class HubertEncoder(nn.Module):
      """Transformer encoder over projected features.

      Adds convolutional positional embeddings, applies LayerNorm and dropout, then runs a stack of
      ``HubertEncoderLayer`` blocks with LayerDrop (``config.layerdrop``) during training.
      """

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.pos_conv_embed = HubertPositionalConvEmbedding(config)
          self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout)
          self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
          self.gradient_checkpointing = False

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          output_attentions: bool = False,
          output_hidden_states: bool = False,
          return_dict: bool = True,
      ):
          """Encode ``hidden_states``; ``attention_mask`` is a 2D padding mask (1/True = keep).

          Returns a ``BaseModelOutput``, or a tuple of the non-None fields when ``return_dict`` is False.
          """
          all_hidden_states = () if output_hidden_states else None
          all_self_attentions = () if output_attentions else None

          if attention_mask is not None:
              # Make sure padded tokens output 0. The cast to bool lets integer 0/1 masks work too:
              # `~` on an integer tensor is a bitwise NOT (-1/-2), not a logical negation. The
              # out-of-place fill leaves the caller's tensor untouched.
              padding_positions = ~attention_mask.bool().unsqueeze(-1)
              hidden_states = hidden_states.masked_fill(padding_positions, 0)

              attention_mask = self._update_full_mask(
                  attention_mask,
                  hidden_states,
              )

          position_embeddings = self.pos_conv_embed(hidden_states)
          hidden_states = hidden_states + position_embeddings
          hidden_states = self.layer_norm(hidden_states)
          hidden_states = self.dropout(hidden_states)

          synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

          for layer in self.layers:
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_states,)

              # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
              dropout_probability = torch.rand([])

              skip_the_layer = self.training and dropout_probability < self.config.layerdrop
              if not skip_the_layer or synced_gpus:
                  # Under fsdp or deepspeed zero3 all GPUs must run every layer in sync, even a "dropped" one.
                  layer_outputs = layer(
                      hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                  )
                  hidden_states = layer_outputs[0]

              if skip_the_layer:
                  layer_outputs = (None, None)

              if output_attentions:
                  all_self_attentions = all_self_attentions + (layer_outputs[1],)

          if output_hidden_states:
              all_hidden_states = all_hidden_states + (hidden_states,)

          if not return_dict:
              return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
          return BaseModelOutput(
              last_hidden_state=hidden_states,
              hidden_states=all_hidden_states,
              attentions=all_self_attentions,
          )

      def _update_full_mask(
          self,
          attention_mask: Union[torch.Tensor, None],
          inputs_embeds: torch.Tensor,
      ):
          """Convert a 2D padding mask into the form the configured attention backend expects."""
          if attention_mask is not None:
              if self.config._attn_implementation == "flash_attention_2":
                  # Flash attention takes the 2D mask directly; drop it when nothing is padded.
                  attention_mask = attention_mask if 0 in attention_mask else None
              elif self.config._attn_implementation == "sdpa":
                  # Expand for SDPA: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] (non-causal).
                  attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
              elif self.config._attn_implementation == "flex_attention":
                  if isinstance(attention_mask, torch.Tensor):
                      attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
              else:
                  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                  attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

          return attention_mask
  class HubertAttnAdapterLayer(nn.Module):
      def __init__(self, config):
          """Adapter: LayerNorm -> Linear(hidden_size -> adapter_attn_dim) -> ReLU -> Linear(adapter_attn_dim -> hidden_size).

          Built from ordinary ``nn.Linear`` layers. The caller adds the output back onto its residual stream.
          """
          super().__init__()
          self.input_dim = config.adapter_attn_dim
          self.hidden_dim = config.hidden_size

          self.norm = nn.LayerNorm(self.hidden_dim)
          self.linear_1 = nn.Linear(self.hidden_dim, self.input_dim)
          self.act_fn = nn.ReLU()
          self.linear_2 = nn.Linear(self.input_dim, self.hidden_dim)

      def forward(self, hidden_states: torch.FloatTensor):
          down = self.act_fn(self.linear_1(self.norm(hidden_states)))
          return self.linear_2(down)
  class HubertEncoderLayerStableLayerNorm(GradientCheckpointingLayer):
      """Pre-norm Transformer block.

      Computes ``x + drop(attn(LN(x)))``, then ``h + ffn(LN(h))``, plus an optional adapter residual when
      ``config.adapter_attn_dim`` is set.
      """

      def __init__(self, config):
          super().__init__()
          self.attention = HubertAttention(
              embed_dim=config.hidden_size,
              num_heads=config.num_attention_heads,
              dropout=config.attention_dropout,
              is_decoder=False,
              config=config,
          )
          self.dropout = nn.Dropout(config.hidden_dropout)
          self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.feed_forward = HubertFeedForward(config)
          self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

          use_adapter = getattr(config, "adapter_attn_dim", None) is not None
          self.adapter_layer = HubertAttnAdapterLayer(config) if use_adapter else None

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          output_attentions: bool = False,
      ):
          """Return ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when ``output_attentions``."""
          residual = hidden_states
          attended, attn_weights, _ = self.attention(
              self.layer_norm(hidden_states), attention_mask=attention_mask, output_attentions=output_attentions
          )
          hidden_states = residual + self.dropout(attended)
          hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))

          if self.adapter_layer is not None:
              hidden_states = hidden_states + self.adapter_layer(hidden_states)

          return (hidden_states, attn_weights) if output_attentions else (hidden_states,)
  class HubertEncoderStableLayerNorm(nn.Module):
      """Pre-norm variant of the Transformer encoder.

      Adds convolutional positional embeddings and dropout, runs ``HubertEncoderLayerStableLayerNorm``
      blocks with LayerDrop (``config.layerdrop``) during training, and applies LayerNorm once after the
      last layer.
      """

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.pos_conv_embed = HubertPositionalConvEmbedding(config)
          self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout)
          self.layers = nn.ModuleList(
              [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
          )
          self.gradient_checkpointing = False

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          output_attentions: bool = False,
          output_hidden_states: bool = False,
          return_dict: bool = True,
      ):
          """Encode ``hidden_states``; ``attention_mask`` is a 2D padding mask (1/True = keep).

          Returns a ``BaseModelOutput``, or a tuple of the non-None fields when ``return_dict`` is False.
          """
          all_hidden_states = () if output_hidden_states else None
          all_self_attentions = () if output_attentions else None

          if attention_mask is not None:
              # Make sure padded tokens output 0. The cast to bool lets integer 0/1 masks work too:
              # `~` on an integer tensor is a bitwise NOT (-1/-2), not a logical negation. The
              # out-of-place fill leaves the caller's tensor untouched.
              padding_positions = ~attention_mask.bool().unsqueeze(-1)
              hidden_states = hidden_states.masked_fill(padding_positions, 0)

              attention_mask = self._update_full_mask(
                  attention_mask,
                  hidden_states,
              )

          position_embeddings = self.pos_conv_embed(hidden_states)
          hidden_states = hidden_states + position_embeddings
          hidden_states = self.dropout(hidden_states)

          synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)

          for layer in self.layers:
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_states,)

              # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
              dropout_probability = torch.rand([])

              skip_the_layer = self.training and dropout_probability < self.config.layerdrop
              if not skip_the_layer or synced_gpus:
                  # Under fsdp or deepspeed zero3 all GPUs must run every layer in sync, even a "dropped" one.
                  # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                  layer_outputs = layer(
                      hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                  )
                  hidden_states = layer_outputs[0]

              if skip_the_layer:
                  layer_outputs = (None, None)

              if output_attentions:
                  all_self_attentions = all_self_attentions + (layer_outputs[1],)

          # Pre-norm stack: normalise once after the final layer.
          hidden_states = self.layer_norm(hidden_states)

          if output_hidden_states:
              all_hidden_states = all_hidden_states + (hidden_states,)

          if not return_dict:
              return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
          return BaseModelOutput(
              last_hidden_state=hidden_states,
              hidden_states=all_hidden_states,
              attentions=all_self_attentions,
          )

      def _update_full_mask(
          self,
          attention_mask: Union[torch.Tensor, None],
          inputs_embeds: torch.Tensor,
      ):
          """Convert a 2D padding mask into the form the configured attention backend expects."""
          if attention_mask is not None:
              if self.config._attn_implementation == "flash_attention_2":
                  # Flash attention takes the 2D mask directly; drop it when nothing is padded.
                  attention_mask = attention_mask if 0 in attention_mask else None
              elif self.config._attn_implementation == "sdpa":
                  # Expand for SDPA: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] (non-causal).
                  attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
              elif self.config._attn_implementation == "flex_attention":
                  if isinstance(attention_mask, torch.Tensor):
                      attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
              else:
                  # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
                  attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)

          return attention_mask
  @auto_docstring
  class HubertPreTrainedModel(PreTrainedModel):
      config: HubertConfig
      base_model_prefix = "hubert"
      main_input_name = "input_values"
      supports_gradient_checkpointing = True
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True

      def _init_weights(self, module):
          """Initialize the weights"""
          if isinstance(module, nn.Linear):
              # Slightly different from the TF version which uses truncated_normal for initialization
              # cf https://github.com/pytorch/pytorch/pull/5617
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  module.bias.data.zero_()
          elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)
          elif isinstance(module, nn.Conv1d):
              if is_deepspeed_zero3_enabled():
                  import deepspeed

                  # Gather the partitioned tensors so the init writes to the full weight.
                  if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
                      with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
                          nn.init.kaiming_normal_(module.weight.data)
                  else:
                      with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
                          nn.init.kaiming_normal_(module.weight.data)
              else:
                  nn.init.kaiming_normal_(module.weight.data)

              if module.bias is not None:
                  module.bias.data.zero_()
          elif isinstance(module, HubertModel):
              if hasattr(module, "masked_spec_embed"):
                  module.masked_spec_embed.data.uniform_()
          elif isinstance(module, HubertForSequenceClassification):
              if hasattr(module, "layer_weights"):
                  module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1))

      def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
          """
          Computes the output length of the convolutional layers
          """

          def _conv_out_length(input_length, kernel_size, stride):
              # 1D convolutional layer output length formula taken
              # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
              return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

          for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
              input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

          return input_lengths

      def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
          """Downsample a sample-level padding mask to a frame-level boolean mask.

          Args:
              feature_vector_length: number of frames produced by the feature encoder.
              attention_mask: 2D mask over raw input samples; each row's sum is taken as its valid length.

          Returns:
              Bool tensor of shape ``(batch_size, feature_vector_length)``; frame ``t`` of row ``b`` is True
              iff ``t`` is less than that row's conv output length.
          """
          output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
          frame_positions = torch.arange(feature_vector_length, device=attention_mask.device)
          # Compare positions against lengths directly instead of scattering a 1 at `length - 1` and
          # back-filling. The scatter wrapped around to the last frame(s) for rows whose input is shorter
          # than the conv receptive field (length <= 0), marking padding as valid. Here such rows come
          # out fully masked.
          return frame_positions.unsqueeze(0) < output_lengths.unsqueeze(-1)
  608. def _compute_mask_indices(
  609. shape: tuple[int, int],
  610. mask_prob: float,
  611. mask_length: int,
  612. attention_mask: Optional[torch.LongTensor] = None,
  613. min_masks: int = 0,
  614. ) -> np.ndarray:
  615. """
  616. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  617. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  618. CPU as part of the preprocessing during training.
  619. Args:
  620. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  621. the first element is the batch size and the second element is the length of the axis to span.
  622. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  623. independently generated mask spans of length `mask_length` is computed by
  624. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  625. actual percentage will be smaller.
  626. mask_length: size of the mask
  627. min_masks: minimum number of masked spans
  628. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  629. each batch dimension.
  630. """
  631. batch_size, sequence_length = shape
  632. if mask_length < 1:
  633. raise ValueError("`mask_length` has to be bigger than 0.")
  634. if mask_length > sequence_length:
  635. raise ValueError(
  636. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  637. f" and `sequence_length`: {sequence_length}`"
  638. )
  639. # epsilon is used for probabilistic rounding
  640. epsilon = np.random.rand(1).item()
  641. def compute_num_masked_span(input_length):
  642. """Given input length, compute how many spans should be masked"""
  643. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  644. num_masked_span = max(num_masked_span, min_masks)
  645. # make sure num masked span <= sequence_length
  646. if num_masked_span * mask_length > sequence_length:
  647. num_masked_span = sequence_length // mask_length
  648. # make sure num_masked span is also <= input_length - (mask_length - 1)
  649. if input_length - (mask_length - 1) < num_masked_span:
  650. num_masked_span = max(input_length - (mask_length - 1), 0)
  651. return num_masked_span
  652. # compute number of masked spans in batch
  653. input_lengths = (
  654. attention_mask.detach().sum(-1).tolist()
  655. if attention_mask is not None
  656. else [sequence_length for _ in range(batch_size)]
  657. )
  658. # SpecAugment mask to fill
  659. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  660. spec_aug_mask_idxs = []
  661. max_num_masked_span = compute_num_masked_span(sequence_length)
  662. if max_num_masked_span == 0:
  663. return spec_aug_mask
  664. for input_length in input_lengths:
  665. # compute num of masked spans for this input
  666. num_masked_span = compute_num_masked_span(input_length)
  667. # get random indices to mask
  668. spec_aug_mask_idx = np.random.choice(
  669. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  670. )
  671. # pick first sampled index that will serve as a dummy index to pad vector
  672. # to ensure same dimension for all batches due to probabilistic rounding
  673. # Picking first sample just pads those vectors twice.
  674. if len(spec_aug_mask_idx) == 0:
  675. # this case can only happen if `input_length` is strictly smaller then
  676. # `sequence_length` in which case the last token has to be a padding
  677. # token which we can use as a dummy mask id
  678. dummy_mask_idx = sequence_length - 1
  679. else:
  680. dummy_mask_idx = spec_aug_mask_idx[0]
  681. spec_aug_mask_idx = np.concatenate(
  682. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  683. )
  684. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  685. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  686. # expand masked indices to masked spans
  687. spec_aug_mask_idxs = np.broadcast_to(
  688. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  689. )
  690. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  691. # add offset to the starting indexes so that indexes now create a span
  692. offsets = np.arange(mask_length)[None, None, :]
  693. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  694. batch_size, max_num_masked_span * mask_length
  695. )
  696. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  697. # ensure that we cannot have indices larger than sequence_length
  698. if spec_aug_mask_idxs.max() > sequence_length - 1:
  699. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  700. # scatter indices to mask
  701. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  702. return spec_aug_mask
  703. @auto_docstring
  704. class HubertModel(HubertPreTrainedModel):
  705. def __init__(self, config: HubertConfig):
  706. super().__init__(config)
  707. self.config = config
  708. self.feature_extractor = HubertFeatureEncoder(config)
  709. self.feature_projection = HubertFeatureProjection(config)
  710. # model only needs masking vector if mask prob is > 0.0
  711. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  712. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  713. if config.do_stable_layer_norm:
  714. self.encoder = HubertEncoderStableLayerNorm(config)
  715. else:
  716. self.encoder = HubertEncoder(config)
  717. # Initialize weights and apply final processing
  718. self.post_init()
  719. def _mask_hidden_states(
  720. self,
  721. hidden_states: torch.FloatTensor,
  722. mask_time_indices: Optional[torch.FloatTensor] = None,
  723. attention_mask: Optional[torch.LongTensor] = None,
  724. ):
  725. """
  726. Masks extracted features along time axis and/or along feature axis according to
  727. [SpecAugment](https://huggingface.co/papers/1904.08779).
  728. """
  729. # `config.apply_spec_augment` can set masking to False
  730. if not getattr(self.config, "apply_spec_augment", True):
  731. return hidden_states
  732. # generate indices & apply SpecAugment along time axis
  733. batch_size, sequence_length, hidden_size = hidden_states.size()
  734. if mask_time_indices is not None:
  735. # apply SpecAugment along time axis with given mask_time_indices
  736. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  737. elif self.config.mask_time_prob > 0 and self.training:
  738. mask_time_indices = _compute_mask_indices(
  739. (batch_size, sequence_length),
  740. mask_prob=self.config.mask_time_prob,
  741. mask_length=self.config.mask_time_length,
  742. attention_mask=attention_mask,
  743. min_masks=self.config.mask_time_min_masks,
  744. )
  745. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  746. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  747. if self.config.mask_feature_prob > 0 and self.training:
  748. # generate indices & apply SpecAugment along feature axis
  749. mask_feature_indices = _compute_mask_indices(
  750. (batch_size, hidden_size),
  751. mask_prob=self.config.mask_feature_prob,
  752. mask_length=self.config.mask_feature_length,
  753. min_masks=self.config.mask_feature_min_masks,
  754. )
  755. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  756. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  757. hidden_states[mask_feature_indices] = 0
  758. return hidden_states
  759. @auto_docstring
  760. def forward(
  761. self,
  762. input_values: Optional[torch.Tensor],
  763. attention_mask: Optional[torch.Tensor] = None,
  764. mask_time_indices: Optional[torch.FloatTensor] = None,
  765. output_attentions: Optional[bool] = None,
  766. output_hidden_states: Optional[bool] = None,
  767. return_dict: Optional[bool] = None,
  768. ) -> Union[tuple, BaseModelOutput]:
  769. r"""
  770. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  771. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  772. masked extracted features in *config.proj_codevector_dim* space.
  773. Example:
  774. ```python
  775. >>> from transformers import AutoProcessor, HubertModel
  776. >>> from datasets import load_dataset
  777. >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
  778. >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
  779. >>> def map_to_array(example):
  780. ... example["speech"] = example["audio"]["array"]
  781. ... return example
  782. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  783. >>> ds = ds.map(map_to_array)
  784. >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
  785. >>> hidden_states = model(input_values).last_hidden_state
  786. ```"""
  787. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  788. output_hidden_states = (
  789. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  790. )
  791. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  792. extract_features = self.feature_extractor(input_values)
  793. extract_features = extract_features.transpose(1, 2)
  794. if attention_mask is not None:
  795. # compute reduced attention_mask corresponding to feature vectors
  796. attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
  797. hidden_states = self.feature_projection(extract_features)
  798. hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
  799. encoder_outputs = self.encoder(
  800. hidden_states,
  801. attention_mask=attention_mask,
  802. output_attentions=output_attentions,
  803. output_hidden_states=output_hidden_states,
  804. return_dict=return_dict,
  805. )
  806. hidden_states = encoder_outputs[0]
  807. if not return_dict:
  808. return (hidden_states,) + encoder_outputs[1:]
  809. return BaseModelOutput(
  810. last_hidden_state=hidden_states,
  811. hidden_states=encoder_outputs.hidden_states,
  812. attentions=encoder_outputs.attentions,
  813. )
  814. _HIDDEN_STATES_START_POSITION = 1
  815. @auto_docstring(
  816. custom_intro="""
  817. Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
  818. """
  819. )
  820. class HubertForCTC(HubertPreTrainedModel):
  821. def __init__(self, config, target_lang: Optional[str] = None):
  822. r"""
  823. target_lang (`str`, *optional*):
  824. Language id of adapter weights. Adapter weights are stored in the format adapter.<lang>.safetensors or
  825. adapter.<lang>.bin. Only relevant when using an instance of [`HubertForCTC`] with adapters. Uses 'eng' by
  826. default.
  827. """
  828. super().__init__(config)
  829. self.hubert = HubertModel(config)
  830. self.dropout = nn.Dropout(config.final_dropout)
  831. self.target_lang = target_lang
  832. if config.vocab_size is None:
  833. raise ValueError(
  834. f"You are trying to instantiate {self.__class__} with a configuration that "
  835. "does not define the vocabulary size of the language model head. Please "
  836. "instantiate the model as follows: `HubertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
  837. "or define `vocab_size` of your model's configuration."
  838. )
  839. output_hidden_size = (
  840. config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
  841. )
  842. self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
  843. # Initialize weights and apply final processing
  844. self.post_init()
  845. def tie_weights(self):
  846. """
  847. This method overwrites [`~PreTrainedModel.tie_weights`] so that adapter weights can be correctly loaded when
  848. passing `target_lang=...` to `from_pretrained(...)`.
  849. This method is **not** supposed to be called by the user and is prone to be changed in the future.
  850. """
  851. # Note that `tie_weights` is usually used to tie input and output embedding weights. The method is re-purposed to
  852. # correctly load adapter layers for Hubert so that we do not have to introduce a new API to
  853. # [`PreTrainedModel`]. While slightly hacky, Hubert never has to tie input and output embeddings, so that it is
  854. # ok to repurpose this function here.
  855. target_lang = self.target_lang
  856. if target_lang is not None and getattr(self.config, "adapter_attn_dim", None) is None:
  857. raise ValueError(f"Cannot pass `target_lang`: {target_lang} if `config.adapter_attn_dim` is not defined.")
  858. elif target_lang is None and getattr(self.config, "adapter_attn_dim", None) is not None:
  859. logger.info("By default `target_lang` is set to 'eng'.")
  860. elif target_lang is not None:
  861. self.load_adapter(target_lang, force_load=True)
  862. def freeze_feature_extractor(self):
  863. """
  864. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  865. not be updated during training.
  866. """
  867. warnings.warn(
  868. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  869. "Please use the equivalent `freeze_feature_encoder` method instead.",
  870. FutureWarning,
  871. )
  872. self.freeze_feature_encoder()
  873. def freeze_feature_encoder(self):
  874. """
  875. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  876. not be updated during training.
  877. """
  878. self.hubert.feature_extractor._freeze_parameters()
  879. def freeze_base_model(self):
  880. """
  881. Calling this function will disable the gradient computation for the base model so that its parameters will not
  882. be updated during training. Only the classification head will be updated.
  883. """
  884. for param in self.hubert.parameters():
  885. param.requires_grad = False
  886. @auto_docstring
  887. def forward(
  888. self,
  889. input_values: Optional[torch.Tensor],
  890. attention_mask: Optional[torch.Tensor] = None,
  891. output_attentions: Optional[bool] = None,
  892. output_hidden_states: Optional[bool] = None,
  893. return_dict: Optional[bool] = None,
  894. labels: Optional[torch.Tensor] = None,
  895. ) -> Union[tuple, CausalLMOutput]:
  896. r"""
  897. labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
  898. Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
  899. the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
  900. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  901. config.vocab_size - 1]`.
  902. """
  903. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  904. if labels is not None and labels.max() >= self.config.vocab_size:
  905. raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
  906. outputs = self.hubert(
  907. input_values,
  908. attention_mask=attention_mask,
  909. output_attentions=output_attentions,
  910. output_hidden_states=output_hidden_states,
  911. return_dict=return_dict,
  912. )
  913. hidden_states = outputs[0]
  914. hidden_states = self.dropout(hidden_states)
  915. logits = self.lm_head(hidden_states)
  916. loss = None
  917. if labels is not None:
  918. # retrieve loss input_lengths from attention_mask
  919. attention_mask = (
  920. attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
  921. )
  922. input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  923. # assuming that padded tokens are filled with -100
  924. # when not being attended to
  925. labels_mask = labels >= 0
  926. target_lengths = labels_mask.sum(-1)
  927. flattened_targets = labels.masked_select(labels_mask)
  928. # ctc_loss doesn't support fp16
  929. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  930. with torch.backends.cudnn.flags(enabled=False):
  931. loss = nn.functional.ctc_loss(
  932. log_probs,
  933. flattened_targets,
  934. input_lengths,
  935. target_lengths,
  936. blank=self.config.pad_token_id,
  937. reduction=self.config.ctc_loss_reduction,
  938. zero_infinity=self.config.ctc_zero_infinity,
  939. )
  940. if not return_dict:
  941. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  942. return ((loss,) + output) if loss is not None else output
  943. return CausalLMOutput(
  944. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  945. )
  946. @auto_docstring(
  947. custom_intro="""
  948. Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
  949. SUPERB Keyword Spotting.
  950. """
  951. )
  952. class HubertForSequenceClassification(HubertPreTrainedModel):
  953. def __init__(self, config):
  954. super().__init__(config)
  955. if hasattr(config, "add_adapter") and config.add_adapter:
  956. raise ValueError(
  957. "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
  958. )
  959. self.hubert = HubertModel(config)
  960. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  961. if config.use_weighted_layer_sum:
  962. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  963. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  964. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  965. # Initialize weights and apply final processing
  966. self.post_init()
  967. def freeze_feature_extractor(self):
  968. """
  969. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  970. not be updated during training.
  971. """
  972. warnings.warn(
  973. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  974. "Please use the equivalent `freeze_feature_encoder` method instead.",
  975. FutureWarning,
  976. )
  977. self.freeze_feature_encoder()
  978. def freeze_feature_encoder(self):
  979. """
  980. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  981. not be updated during training.
  982. """
  983. self.hubert.feature_extractor._freeze_parameters()
  984. def freeze_base_model(self):
  985. """
  986. Calling this function will disable the gradient computation for the base model so that its parameters will not
  987. be updated during training. Only the classification head will be updated.
  988. """
  989. for param in self.hubert.parameters():
  990. param.requires_grad = False
  991. @auto_docstring
  992. def forward(
  993. self,
  994. input_values: Optional[torch.Tensor],
  995. attention_mask: Optional[torch.Tensor] = None,
  996. output_attentions: Optional[bool] = None,
  997. output_hidden_states: Optional[bool] = None,
  998. return_dict: Optional[bool] = None,
  999. labels: Optional[torch.Tensor] = None,
  1000. ) -> Union[tuple, SequenceClassifierOutput]:
  1001. r"""
  1002. input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
  1003. Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file
  1004. into an array of type `list[float]`, a `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library
  1005. (`pip install torchcodec`) or the soundfile library (`pip install soundfile`).
  1006. To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and conversion
  1007. into a tensor of type `torch.FloatTensor`. See [`HubertProcessor.__call__`] for details.
  1008. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1009. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1010. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1011. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1012. """
  1013. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1014. output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
  1015. outputs = self.hubert(
  1016. input_values,
  1017. attention_mask=attention_mask,
  1018. output_attentions=output_attentions,
  1019. output_hidden_states=output_hidden_states,
  1020. return_dict=return_dict,
  1021. )
  1022. if self.config.use_weighted_layer_sum:
  1023. hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
  1024. hidden_states = torch.stack(hidden_states, dim=1)
  1025. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1026. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1027. else:
  1028. hidden_states = outputs[0]
  1029. hidden_states = self.projector(hidden_states)
  1030. if attention_mask is None:
  1031. pooled_output = hidden_states.mean(dim=1)
  1032. else:
  1033. padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  1034. expand_padding_mask = padding_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  1035. hidden_states[~expand_padding_mask] = 0.0
  1036. pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
  1037. logits = self.classifier(pooled_output)
  1038. loss = None
  1039. if labels is not None:
  1040. loss_fct = CrossEntropyLoss()
  1041. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1042. if not return_dict:
  1043. output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
  1044. return ((loss,) + output) if loss is not None else output
  1045. return SequenceClassifierOutput(
  1046. loss=loss,
  1047. logits=logits,
  1048. hidden_states=outputs.hidden_states,
  1049. attentions=outputs.attentions,
  1050. )
  1051. __all__ = ["HubertForCTC", "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel"]