modeling_whisper.py 72 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630
  1. # coding=utf-8
  2. # Copyright 2022 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Whisper model."""
  16. import math
  17. from typing import Callable, Optional, Union
  18. import numpy as np
  19. import torch
  20. from torch import nn
  21. from torch.nn import CrossEntropyLoss
  22. from ...activations import ACT2FN
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...masking_utils import create_causal_mask
  26. from ...modeling_flash_attention_utils import (
  27. FlashAttentionKwargs,
  28. )
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import (
  31. BaseModelOutput,
  32. BaseModelOutputWithPastAndCrossAttentions,
  33. CausalLMOutputWithCrossAttentions,
  34. Seq2SeqLMOutput,
  35. Seq2SeqModelOutput,
  36. SequenceClassifierOutput,
  37. )
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import auto_docstring, logging
  41. from ...utils.deprecation import deprecate_kwarg
  42. from .configuration_whisper import WhisperConfig
  43. from .generation_whisper import WhisperGenerationMixin
# Module-level logger; used below, e.g. by `WhisperAttention.__init__` to emit a one-time warning.
logger = logging.get_logger(__name__)
# NOTE(review): not referenced in the visible part of this file. Presumably this is the index in a
# `hidden_states` output tuple where per-layer states start; confirm against its users.
_HIDDEN_STATES_START_POSITION = 1
  46. def sinusoids(length: int, channels: int, max_timescale: float = 10000) -> torch.Tensor:
  47. """Returns sinusoids for positional embedding"""
  48. if channels % 2 != 0:
  49. raise ValueError(
  50. f"Number of channels has to be divisible by 2 for sinusoidal positional embeddings, got {channels} channels."
  51. )
  52. log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
  53. inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
  54. scaled_time = torch.arange(length).view(-1, 1) * inv_timescales.view(1, -1)
  55. return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
  56. # Copied from transformers.models.bart.modeling_bart.shift_tokens_right
  57. def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
  58. """
  59. Shift input ids one token to the right.
  60. """
  61. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  62. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  63. shifted_input_ids[:, 0] = decoder_start_token_id
  64. if pad_token_id is None:
  65. raise ValueError("self.model.config.pad_token_id has to be defined.")
  66. # replace possible -100 values in labels by `pad_token_id`
  67. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  68. return shifted_input_ids
  69. # Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
  70. def _compute_mask_indices(
  71. shape: tuple[int, int],
  72. mask_prob: float,
  73. mask_length: int,
  74. attention_mask: Optional[torch.LongTensor] = None,
  75. min_masks: int = 0,
  76. ) -> np.ndarray:
  77. """
  78. Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
  79. ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
  80. CPU as part of the preprocessing during training.
  81. Args:
  82. shape: The shape for which to compute masks. This should be of a tuple of size 2 where
  83. the first element is the batch size and the second element is the length of the axis to span.
  84. mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
  85. independently generated mask spans of length `mask_length` is computed by
  86. `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
  87. actual percentage will be smaller.
  88. mask_length: size of the mask
  89. min_masks: minimum number of masked spans
  90. attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
  91. each batch dimension.
  92. """
  93. batch_size, sequence_length = shape
  94. if mask_length < 1:
  95. raise ValueError("`mask_length` has to be bigger than 0.")
  96. if mask_length > sequence_length:
  97. raise ValueError(
  98. f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
  99. f" and `sequence_length`: {sequence_length}`"
  100. )
  101. # epsilon is used for probabilistic rounding
  102. epsilon = np.random.rand(1).item()
  103. def compute_num_masked_span(input_length):
  104. """Given input length, compute how many spans should be masked"""
  105. num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
  106. num_masked_span = max(num_masked_span, min_masks)
  107. # make sure num masked span <= sequence_length
  108. if num_masked_span * mask_length > sequence_length:
  109. num_masked_span = sequence_length // mask_length
  110. # make sure num_masked span is also <= input_length - (mask_length - 1)
  111. if input_length - (mask_length - 1) < num_masked_span:
  112. num_masked_span = max(input_length - (mask_length - 1), 0)
  113. return num_masked_span
  114. # compute number of masked spans in batch
  115. input_lengths = (
  116. attention_mask.detach().sum(-1).tolist()
  117. if attention_mask is not None
  118. else [sequence_length for _ in range(batch_size)]
  119. )
  120. # SpecAugment mask to fill
  121. spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
  122. spec_aug_mask_idxs = []
  123. max_num_masked_span = compute_num_masked_span(sequence_length)
  124. if max_num_masked_span == 0:
  125. return spec_aug_mask
  126. for input_length in input_lengths:
  127. # compute num of masked spans for this input
  128. num_masked_span = compute_num_masked_span(input_length)
  129. # get random indices to mask
  130. spec_aug_mask_idx = np.random.choice(
  131. np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
  132. )
  133. # pick first sampled index that will serve as a dummy index to pad vector
  134. # to ensure same dimension for all batches due to probabilistic rounding
  135. # Picking first sample just pads those vectors twice.
  136. if len(spec_aug_mask_idx) == 0:
  137. # this case can only happen if `input_length` is strictly smaller then
  138. # `sequence_length` in which case the last token has to be a padding
  139. # token which we can use as a dummy mask id
  140. dummy_mask_idx = sequence_length - 1
  141. else:
  142. dummy_mask_idx = spec_aug_mask_idx[0]
  143. spec_aug_mask_idx = np.concatenate(
  144. [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
  145. )
  146. spec_aug_mask_idxs.append(spec_aug_mask_idx)
  147. spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
  148. # expand masked indices to masked spans
  149. spec_aug_mask_idxs = np.broadcast_to(
  150. spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
  151. )
  152. spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
  153. # add offset to the starting indexes so that indexes now create a span
  154. offsets = np.arange(mask_length)[None, None, :]
  155. offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
  156. batch_size, max_num_masked_span * mask_length
  157. )
  158. spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
  159. # ensure that we cannot have indices larger than sequence_length
  160. if spec_aug_mask_idxs.max() > sequence_length - 1:
  161. spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
  162. # scatter indices to mask
  163. np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
  164. return spec_aug_mask
  165. class WhisperPositionalEmbedding(nn.Embedding):
  166. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
  167. super().__init__(num_positions, embedding_dim)
  168. def forward(self, input_ids, past_key_values_length=0, position_ids=None):
  169. if position_ids is None:
  170. return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
  171. else:
  172. return self.weight[position_ids]
  173. def eager_attention_forward(
  174. module: nn.Module,
  175. query: torch.Tensor,
  176. key: torch.Tensor,
  177. value: torch.Tensor,
  178. attention_mask: Optional[torch.Tensor],
  179. scaling: Optional[float] = None,
  180. dropout: float = 0.0,
  181. head_mask: Optional[torch.Tensor] = None,
  182. **kwargs,
  183. ):
  184. if scaling is None:
  185. scaling = query.size(-1) ** -0.5
  186. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  187. if attention_mask is not None and attention_mask.ndim == 4:
  188. attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
  189. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  190. if head_mask is not None:
  191. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  192. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  193. attn_output = torch.matmul(attn_weights, value)
  194. attn_output = attn_output.transpose(1, 2).contiguous()
  195. return attn_output, attn_weights
  196. class WhisperAttention(nn.Module):
  197. """Multi-headed attention from 'Attention Is All You Need' paper"""
  198. def __init__(
  199. self,
  200. embed_dim: int,
  201. num_heads: int,
  202. dropout: float = 0.0,
  203. is_decoder: bool = False,
  204. bias: bool = True,
  205. is_causal: bool = False,
  206. layer_idx: Optional[int] = None,
  207. config: Optional[WhisperConfig] = None,
  208. ):
  209. super().__init__()
  210. self.embed_dim = embed_dim
  211. self.num_heads = num_heads
  212. self.dropout = dropout
  213. self.head_dim = embed_dim // num_heads
  214. self.config = config
  215. if (self.head_dim * num_heads) != self.embed_dim:
  216. raise ValueError(
  217. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  218. f" and `num_heads`: {num_heads})."
  219. )
  220. self.scaling = self.head_dim**-0.5
  221. self.is_decoder = is_decoder
  222. self.is_causal = is_causal
  223. if layer_idx is None and is_decoder:
  224. logger.warning_once(
  225. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  226. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  227. "when creating this class."
  228. )
  229. self.layer_idx = layer_idx
  230. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
  231. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  232. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  233. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  234. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  235. def forward(
  236. self,
  237. hidden_states: torch.Tensor,
  238. key_value_states: Optional[torch.Tensor] = None,
  239. past_key_values: Optional[Cache] = None,
  240. attention_mask: Optional[torch.Tensor] = None,
  241. layer_head_mask: Optional[torch.Tensor] = None,
  242. output_attentions: bool = False,
  243. cache_position: Optional[torch.Tensor] = None,
  244. # TODO: we need a refactor so that the different attention modules can get their specific kwargs
  245. # ATM, we have mixed things encoder, decoder, and encoder-decoder attn
  246. **kwargs: Unpack[FlashAttentionKwargs],
  247. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  248. """Input shape: Batch x Time x Channel"""
  249. # if key_value_states are provided this layer is used as a cross-attention layer
  250. # for the decoder
  251. is_cross_attention = key_value_states is not None
  252. # determine input shapes
  253. bsz, tgt_len = hidden_states.shape[:-1]
  254. q_input_shape = (bsz, tgt_len, -1, self.head_dim)
  255. # Scaling is susceptible to floating point arithmetics' inprecisions
  256. # which can lead to different results (this is dependent from model
  257. # to model, e.g. whisper is one such case). We therefore keep the
  258. # original order of scaling to follow the original implementation
  259. # and enforce no scaling (1.0) in the attention call below.
  260. query_states = self.q_proj(hidden_states) * self.scaling
  261. query_states = query_states.view(*q_input_shape)
  262. query_states = query_states.transpose(1, 2).contiguous()
  263. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  264. if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
  265. is_updated = past_key_values.is_updated.get(self.layer_idx)
  266. if is_cross_attention:
  267. # after the first generated id, we can subsequently re-use all key/value_states from cache
  268. past_key_values.is_updated[self.layer_idx] = True
  269. past_key_values = past_key_values.cross_attention_cache
  270. else:
  271. past_key_values = past_key_values.self_attention_cache
  272. # use key_value_states if cross attention
  273. current_states = key_value_states if key_value_states is not None else hidden_states
  274. if is_cross_attention and past_key_values and is_updated:
  275. # reuse k,v, cross_attentions
  276. key_states = past_key_values.layers[self.layer_idx].keys
  277. value_states = past_key_values.layers[self.layer_idx].values
  278. else:
  279. key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
  280. value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
  281. key_states = key_states.transpose(1, 2).contiguous()
  282. value_states = value_states.transpose(1, 2).contiguous()
  283. if past_key_values is not None:
  284. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  285. cache_position = cache_position if not is_cross_attention else None
  286. key_states, value_states = past_key_values.update(
  287. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  288. )
  289. attention_interface: Callable = eager_attention_forward
  290. if self.config._attn_implementation != "eager":
  291. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  292. attn_output, attn_weights = attention_interface(
  293. self,
  294. query_states,
  295. key_states,
  296. value_states,
  297. attention_mask,
  298. dropout=0.0 if not self.training else self.dropout,
  299. scaling=1.0,
  300. output_attentions=output_attentions,
  301. head_mask=layer_head_mask,
  302. **kwargs,
  303. )
  304. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  305. attn_output = self.out_proj(attn_output)
  306. return attn_output, attn_weights
  307. # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
  308. class WhisperEncoderLayer(GradientCheckpointingLayer):
  309. def __init__(self, config: WhisperConfig):
  310. super().__init__()
  311. self.embed_dim = config.d_model
  312. self.self_attn = WhisperAttention(
  313. embed_dim=self.embed_dim,
  314. num_heads=config.encoder_attention_heads,
  315. dropout=config.attention_dropout,
  316. config=config,
  317. )
  318. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  319. self.dropout = config.dropout
  320. self.activation_fn = ACT2FN[config.activation_function]
  321. self.activation_dropout = config.activation_dropout
  322. self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
  323. self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
  324. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  325. def forward(
  326. self,
  327. hidden_states: torch.Tensor,
  328. attention_mask: torch.Tensor,
  329. layer_head_mask: torch.Tensor,
  330. output_attentions: bool = False,
  331. ) -> torch.Tensor:
  332. """
  333. Args:
  334. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  335. attention_mask (`torch.FloatTensor`): attention mask of size
  336. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  337. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  338. `(encoder_attention_heads,)`.
  339. output_attentions (`bool`, *optional*):
  340. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  341. returned tensors for more detail.
  342. """
  343. residual = hidden_states
  344. hidden_states = self.self_attn_layer_norm(hidden_states)
  345. hidden_states, attn_weights = self.self_attn(
  346. hidden_states=hidden_states,
  347. attention_mask=attention_mask,
  348. layer_head_mask=layer_head_mask,
  349. output_attentions=output_attentions,
  350. )
  351. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  352. hidden_states = residual + hidden_states
  353. residual = hidden_states
  354. hidden_states = self.final_layer_norm(hidden_states)
  355. hidden_states = self.activation_fn(self.fc1(hidden_states))
  356. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  357. hidden_states = self.fc2(hidden_states)
  358. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  359. hidden_states = residual + hidden_states
  360. if hidden_states.dtype == torch.float16:
  361. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  362. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  363. return hidden_states, attn_weights
  364. class WhisperDecoderLayer(GradientCheckpointingLayer):
  365. def __init__(self, config: WhisperConfig, layer_idx: Optional[int] = None):
  366. super().__init__()
  367. self.embed_dim = config.d_model
  368. self.self_attn = WhisperAttention(
  369. embed_dim=self.embed_dim,
  370. num_heads=config.decoder_attention_heads,
  371. dropout=config.attention_dropout,
  372. is_decoder=True,
  373. is_causal=True,
  374. layer_idx=layer_idx,
  375. config=config,
  376. )
  377. self.dropout = config.dropout
  378. self.activation_fn = ACT2FN[config.activation_function]
  379. self.activation_dropout = config.activation_dropout
  380. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  381. self.encoder_attn = WhisperAttention(
  382. self.embed_dim,
  383. config.decoder_attention_heads,
  384. dropout=config.attention_dropout,
  385. is_decoder=True,
  386. layer_idx=layer_idx,
  387. config=config,
  388. )
  389. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
  390. self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
  391. self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
  392. self.final_layer_norm = nn.LayerNorm(self.embed_dim)
  393. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  394. def forward(
  395. self,
  396. hidden_states: torch.Tensor,
  397. attention_mask: Optional[torch.Tensor] = None,
  398. encoder_hidden_states: Optional[torch.Tensor] = None,
  399. encoder_attention_mask: Optional[torch.Tensor] = None,
  400. layer_head_mask: Optional[torch.Tensor] = None,
  401. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  402. past_key_values: Optional[EncoderDecoderCache] = None,
  403. output_attentions: Optional[bool] = False,
  404. use_cache: Optional[bool] = True,
  405. cache_position: Optional[torch.LongTensor] = None,
  406. ) -> torch.Tensor:
  407. """
  408. Args:
  409. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  410. attention_mask (`torch.FloatTensor`): attention mask of size
  411. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  412. encoder_hidden_states (`torch.FloatTensor`):
  413. cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
  414. encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
  415. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  416. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
  417. `(encoder_attention_heads,)`.
  418. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
  419. size `(decoder_attention_heads,)`.
  420. past_key_values (`Cache`): cached past key and value projection states
  421. output_attentions (`bool`, *optional*):
  422. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  423. returned tensors for more detail.
  424. """
  425. residual = hidden_states
  426. hidden_states = self.self_attn_layer_norm(hidden_states)
  427. # Self Attention
  428. hidden_states, self_attn_weights = self.self_attn(
  429. hidden_states=hidden_states,
  430. past_key_values=past_key_values,
  431. attention_mask=attention_mask,
  432. layer_head_mask=layer_head_mask,
  433. output_attentions=output_attentions,
  434. cache_position=cache_position,
  435. )
  436. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  437. hidden_states = residual + hidden_states
  438. # Cross-Attention Block
  439. cross_attn_weights = None
  440. if encoder_hidden_states is not None:
  441. residual = hidden_states
  442. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  443. hidden_states, cross_attn_weights = self.encoder_attn(
  444. hidden_states=hidden_states,
  445. key_value_states=encoder_hidden_states,
  446. attention_mask=encoder_attention_mask,
  447. layer_head_mask=cross_attn_layer_head_mask,
  448. past_key_values=past_key_values,
  449. output_attentions=output_attentions,
  450. )
  451. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  452. hidden_states = residual + hidden_states
  453. # Fully Connected
  454. residual = hidden_states
  455. hidden_states = self.final_layer_norm(hidden_states)
  456. hidden_states = self.activation_fn(self.fc1(hidden_states))
  457. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  458. hidden_states = self.fc2(hidden_states)
  459. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  460. hidden_states = residual + hidden_states
  461. outputs = (hidden_states,)
  462. if output_attentions:
  463. outputs += (self_attn_weights, cross_attn_weights)
  464. return outputs
  465. @auto_docstring
  466. class WhisperPreTrainedModel(PreTrainedModel):
  467. config: WhisperConfig
  468. base_model_prefix = "model"
  469. main_input_name = "input_features"
  470. supports_gradient_checkpointing = True
  471. _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
  472. _supports_flash_attn = True
  473. _supports_sdpa = True
  474. _supports_flex_attn = True
  475. _can_compile_fullgraph = True
  476. def _init_weights(self, module):
  477. std = self.config.init_std
  478. if isinstance(module, (nn.Linear, nn.Conv1d)):
  479. module.weight.data.normal_(mean=0.0, std=std)
  480. if module.bias is not None:
  481. module.bias.data.zero_()
  482. elif isinstance(module, nn.Embedding):
  483. module.weight.data.normal_(mean=0.0, std=std)
  484. if module.padding_idx is not None:
  485. module.weight.data[module.padding_idx].zero_()
  486. elif isinstance(module, nn.LayerNorm):
  487. module.weight.data.fill_(1.0)
  488. module.bias.data.zero_()
  489. elif isinstance(module, WhisperEncoder):
  490. module.embed_positions.weight.copy_(sinusoids(*module.embed_positions.weight.shape))
  491. elif isinstance(module, WhisperForAudioClassification):
  492. if self.config.use_weighted_layer_sum:
  493. module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1))
  494. def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
  495. """
  496. Computes the output length of the convolutional layers
  497. """
  498. input_lengths = (input_lengths - 1) // 2 + 1
  499. return input_lengths
  500. class WhisperEncoder(WhisperPreTrainedModel):
  501. """
  502. Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
  503. [`WhisperEncoderLayer`].
  504. Args:
  505. config: WhisperConfig
  506. """
    def __init__(self, config: WhisperConfig):
        """Build the encoder stem (two 1-D convolutions), the positional table and the layer stack.

        Args:
            config: Model configuration. This method reads `dropout`, `encoder_layerdrop`, `d_model`,
                `num_mel_bins`, `pad_token_id`, `max_source_positions`, `scale_embedding` and
                `encoder_layers` from it.
        """
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.num_mel_bins = config.num_mel_bins
        # NOTE(review): stored but not read by any line visible here; confirm whether forward uses it.
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_source_positions
        # sqrt(d_model) when `config.scale_embedding` is set, otherwise a neutral factor of 1.0.
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        # Convolutional stem. conv1 maps mel bins -> d_model and keeps the length (kernel 3, padding 1).
        # conv2 has stride 2, giving output length (L - 1) // 2 + 1, the formula used by
        # `_get_feat_extract_output_lengths`.
        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)

        # Positional table is frozen; `WhisperPreTrainedModel._init_weights` fills it with `sinusoids(...)`.
        self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
        self.embed_positions.requires_grad_(False)

        # Submodule creation order fixes parameter/state-dict order; keep it as is.
        self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
  525. def _freeze_parameters(self):
  526. for param in self.parameters():
  527. param.requires_grad = False
  528. self._requires_grad = False
  529. def get_input_embeddings(self) -> nn.Module:
  530. return self.conv1
  531. def set_input_embeddings(self, value: nn.Module):
  532. self.conv1 = value
  533. def forward(
  534. self,
  535. input_features,
  536. attention_mask=None,
  537. head_mask=None,
  538. output_attentions=None,
  539. output_hidden_states=None,
  540. return_dict=None,
  541. ):
  542. r"""
  543. Args:
  544. input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
  545. Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
  546. obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
  547. `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
  548. the soundfile library (`pip install soundfile`). To prepare the array into
  549. `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
  550. and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
  551. attention_mask (`torch.Tensor`)`, *optional*):
  552. Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
  553. but it is not used. By default the silence in the input log mel spectrogram are ignored.
  554. head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
  555. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  556. - 1 indicates the head is **not masked**,
  557. - 0 indicates the head is **masked**.
  558. output_attentions (`bool`, *optional*):
  559. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  560. returned tensors for more detail.
  561. output_hidden_states (`bool`, *optional*):
  562. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  563. for more detail.
  564. return_dict (`bool`, *optional*):
  565. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  566. """
  567. expected_seq_length = self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0]
  568. if input_features.shape[-1] != expected_seq_length:
  569. raise ValueError(
  570. f"Whisper expects the mel input features to be of length {expected_seq_length}, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
  571. )
  572. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  573. output_hidden_states = (
  574. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  575. )
  576. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  577. inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  578. inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
  579. inputs_embeds = inputs_embeds.permute(0, 2, 1)
  580. all_positions = torch.arange(self.embed_positions.num_embeddings, device=inputs_embeds.device)
  581. hidden_states = inputs_embeds + self.embed_positions(all_positions)
  582. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  583. encoder_states = () if output_hidden_states else None
  584. all_attentions = () if output_attentions else None
  585. # check if head_mask has a correct number of layers specified if desired
  586. if head_mask is not None:
  587. assert head_mask.size()[0] == (len(self.layers)), (
  588. f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
  589. )
  590. for idx, encoder_layer in enumerate(self.layers):
  591. if output_hidden_states:
  592. encoder_states = encoder_states + (hidden_states,)
  593. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  594. to_drop = False
  595. if self.training:
  596. dropout_probability = torch.rand([])
  597. if dropout_probability < self.layerdrop: # skip the layer
  598. to_drop = True
  599. if to_drop:
  600. layer_outputs = (None, None)
  601. else:
  602. layer_outputs = encoder_layer(
  603. hidden_states,
  604. None,
  605. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  606. output_attentions=output_attentions,
  607. )
  608. hidden_states = layer_outputs[0]
  609. if output_attentions:
  610. all_attentions = all_attentions + (layer_outputs[1],)
  611. hidden_states = self.layer_norm(hidden_states)
  612. if output_hidden_states:
  613. encoder_states = encoder_states + (hidden_states,)
  614. if not return_dict:
  615. return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
  616. return BaseModelOutput(
  617. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  618. )
  619. class WhisperDecoder(WhisperPreTrainedModel):
  620. """
  621. Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`WhisperDecoderLayer`]
  622. Args:
  623. config: WhisperConfig
  624. """
  625. main_input_name = "input_ids"
  626. def __init__(self, config: WhisperConfig):
  627. super().__init__(config)
  628. self.dropout = config.dropout
  629. self.layerdrop = config.decoder_layerdrop
  630. self.padding_idx = config.pad_token_id
  631. self.max_target_positions = config.max_target_positions
  632. self.max_source_positions = config.max_source_positions
  633. self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
  634. self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
  635. self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
  636. self.layers = nn.ModuleList(
  637. [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
  638. )
  639. self.layer_norm = nn.LayerNorm(config.d_model)
  640. self.gradient_checkpointing = False
  641. # Initialize weights and apply final processing
  642. self.post_init()
  643. def forward(
  644. self,
  645. input_ids=None,
  646. attention_mask=None,
  647. encoder_hidden_states=None,
  648. head_mask=None,
  649. cross_attn_head_mask=None,
  650. past_key_values=None,
  651. inputs_embeds=None,
  652. position_ids=None,
  653. use_cache=None,
  654. output_attentions=None,
  655. output_hidden_states=None,
  656. return_dict=None,
  657. cache_position=None,
  658. ):
  659. r"""
  660. Args:
  661. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  662. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
  663. provide it.
  664. Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  665. [`PreTrainedTokenizer.__call__`] for details.
  666. [What are input IDs?](../glossary#input-ids)
  667. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  668. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  669. - 1 for tokens that are **not masked**,
  670. - 0 for tokens that are **masked**.
  671. [What are attention masks?](../glossary#attention-mask)
  672. encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
  673. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  674. of the decoder.
  675. head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  676. Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
  677. - 1 indicates the head is **not masked**,
  678. - 0 indicates the head is **masked**.
  679. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  680. Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
  681. on hidden heads. Mask values selected in `[0, 1]`:
  682. - 1 indicates the head is **not masked**,
  683. - 0 indicates the head is **masked**.
  684. past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
  685. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  686. If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
  687. that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
  688. all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  689. inputs_embeds (`torch.FloatTensor` of
  690. shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing
  691. `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more
  692. control over how to convert `input_ids` indices into associated vectors than the model's internal
  693. embedding lookup matrix.
  694. output_attentions (`bool`, *optional*):
  695. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  696. returned tensors for more detail.
  697. output_hidden_states (`bool`, *optional*):
  698. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  699. for more detail.
  700. return_dict (`bool`, *optional*):
  701. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  702. cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
  703. Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
  704. cache in the correct position and to infer the complete sequence length.
  705. """
  706. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  707. output_hidden_states = (
  708. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  709. )
  710. use_cache = use_cache if use_cache is not None else self.config.use_cache
  711. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  712. # retrieve input_ids and inputs_embeds
  713. if input_ids is not None and inputs_embeds is not None:
  714. raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
  715. elif input_ids is not None:
  716. input_shape = input_ids.size()
  717. input_ids = input_ids.view(-1, input_shape[-1])
  718. elif inputs_embeds is not None:
  719. input_shape = inputs_embeds.size()[:-1]
  720. else:
  721. raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
  722. if inputs_embeds is None:
  723. inputs_embeds = self.embed_tokens(input_ids)
  724. if use_cache and past_key_values is None:
  725. if self.config.is_encoder_decoder:
  726. past_key_values = EncoderDecoderCache(
  727. DynamicCache(config=self.config), DynamicCache(config=self.config)
  728. )
  729. else:
  730. past_key_values = DynamicCache(config=self.config)
  731. past_key_values_length = 0
  732. if cache_position is not None:
  733. past_key_values_length = cache_position[0]
  734. elif past_key_values is not None:
  735. past_key_values_length = past_key_values.get_seq_length()
  736. if cache_position is None:
  737. cache_position = torch.arange(
  738. past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
  739. )
  740. if position_ids is None:
  741. position_ids = cache_position.unsqueeze(0).repeat(input_shape[0], 1)
  742. # embed positions
  743. if input_ids is not None:
  744. positions = self.embed_positions(
  745. input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
  746. )
  747. else:
  748. positions = self.embed_positions(
  749. inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
  750. )
  751. hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
  752. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  753. causal_mask = create_causal_mask(
  754. config=self.config,
  755. input_embeds=inputs_embeds,
  756. attention_mask=attention_mask,
  757. cache_position=cache_position,
  758. past_key_values=past_key_values,
  759. position_ids=position_ids,
  760. )
  761. if self.gradient_checkpointing and self.training:
  762. if use_cache:
  763. logger.warning_once(
  764. "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
  765. )
  766. use_cache = False
  767. # decoder layers
  768. all_hidden_states = () if output_hidden_states else None
  769. all_self_attns = () if output_attentions else None
  770. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  771. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  772. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  773. if attn_mask is not None:
  774. assert attn_mask.size()[0] == (len(self.layers)), (
  775. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  776. f" {head_mask.size()[0]}."
  777. )
  778. for idx, decoder_layer in enumerate(self.layers):
  779. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  780. if output_hidden_states:
  781. all_hidden_states += (hidden_states,)
  782. if self.training:
  783. dropout_probability = torch.rand([])
  784. if dropout_probability < self.layerdrop:
  785. continue
  786. layer_outputs = decoder_layer(
  787. hidden_states,
  788. attention_mask=causal_mask,
  789. encoder_hidden_states=encoder_hidden_states,
  790. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  791. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  792. past_key_values=past_key_values if use_cache else None,
  793. output_attentions=output_attentions,
  794. use_cache=use_cache,
  795. cache_position=cache_position,
  796. )
  797. hidden_states = layer_outputs[0]
  798. if output_attentions:
  799. all_self_attns += (layer_outputs[1],)
  800. if encoder_hidden_states is not None:
  801. all_cross_attentions += (layer_outputs[2],)
  802. hidden_states = self.layer_norm(hidden_states)
  803. # add hidden states from the last decoder layer
  804. if output_hidden_states:
  805. all_hidden_states += (hidden_states,)
  806. next_cache = past_key_values if use_cache else None
  807. if not return_dict:
  808. return tuple(
  809. v
  810. for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
  811. if v is not None
  812. )
  813. return BaseModelOutputWithPastAndCrossAttentions(
  814. last_hidden_state=hidden_states,
  815. past_key_values=next_cache,
  816. hidden_states=all_hidden_states,
  817. attentions=all_self_attns,
  818. cross_attentions=all_cross_attentions,
  819. )
  820. @auto_docstring
  821. class WhisperModel(WhisperPreTrainedModel):
  822. def __init__(self, config: WhisperConfig):
  823. super().__init__(config)
  824. self.encoder = WhisperEncoder(config)
  825. self.decoder = WhisperDecoder(config)
  826. # Initialize weights and apply final processing
  827. self.post_init()
  828. def get_input_embeddings(self):
  829. return self.decoder.embed_tokens
  830. def set_input_embeddings(self, value):
  831. self.decoder.embed_tokens = value
  832. def get_encoder(self):
  833. return self.encoder
  834. def freeze_encoder(self):
  835. """
  836. Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
  837. not be updated during training.
  838. """
  839. self.encoder._freeze_parameters()
  840. def _mask_input_features(
  841. self,
  842. input_features: torch.FloatTensor,
  843. attention_mask: Optional[torch.LongTensor] = None,
  844. ):
  845. """
  846. Masks extracted features along time axis and/or along feature axis according to
  847. [SpecAugment](https://huggingface.co/papers/1904.08779).
  848. """
  849. # `config.apply_spec_augment` can set masking to False
  850. if not getattr(self.config, "apply_spec_augment", True):
  851. return input_features
  852. # generate indices & apply SpecAugment along time axis
  853. batch_size, hidden_size, sequence_length = input_features.size()
  854. if self.config.mask_time_prob > 0 and self.training:
  855. # generate indices & apply SpecAugment along time axis
  856. mask_time_indices = _compute_mask_indices(
  857. (batch_size, sequence_length),
  858. mask_prob=self.config.mask_time_prob,
  859. mask_length=self.config.mask_time_length,
  860. attention_mask=attention_mask,
  861. min_masks=self.config.mask_time_min_masks,
  862. )
  863. mask_time_indices = torch.tensor(mask_time_indices, device=input_features.device, dtype=torch.bool)
  864. mask_time_indices = mask_time_indices[:, None].expand(-1, hidden_size, -1)
  865. input_features[mask_time_indices] = 0
  866. if self.config.mask_feature_prob > 0 and self.training:
  867. # generate indices & apply SpecAugment along feature axis
  868. mask_feature_indices = _compute_mask_indices(
  869. (batch_size, hidden_size),
  870. mask_prob=self.config.mask_feature_prob,
  871. mask_length=self.config.mask_feature_length,
  872. min_masks=self.config.mask_feature_min_masks,
  873. )
  874. mask_feature_indices = torch.tensor(mask_feature_indices, device=input_features.device, dtype=torch.bool)
  875. input_features[mask_feature_indices] = 0
  876. return input_features
  877. @auto_docstring
  878. def forward(
  879. self,
  880. input_features: Optional[torch.FloatTensor] = None,
  881. attention_mask: Optional[torch.LongTensor] = None,
  882. decoder_input_ids: Optional[torch.LongTensor] = None,
  883. decoder_attention_mask: Optional[torch.LongTensor] = None,
  884. head_mask: Optional[torch.Tensor] = None,
  885. decoder_head_mask: Optional[torch.Tensor] = None,
  886. cross_attn_head_mask: Optional[torch.Tensor] = None,
  887. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  888. past_key_values: Optional[Cache] = None,
  889. decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
  890. decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
  891. use_cache: Optional[bool] = None,
  892. output_attentions: Optional[bool] = None,
  893. output_hidden_states: Optional[bool] = None,
  894. return_dict: Optional[bool] = None,
  895. cache_position: Optional[torch.LongTensor] = None,
  896. ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
  897. r"""
  898. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  899. Indices of decoder input sequence tokens in the vocabulary.
  900. Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  901. [`PreTrainedTokenizer.__call__`] for details.
  902. [What are decoder input IDs?](../glossary#decoder-input-ids)
  903. Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
  904. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  905. `past_key_values`).
  906. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  907. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  908. be used by default.
  909. If you want to change padding behavior, you should read
  910. [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
  911. paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
  912. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  913. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  914. - 1 indicates the head is **not masked**,
  915. - 0 indicates the head is **masked**.
  916. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  917. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  918. config.n_positions - 1]`.
  919. [What are position IDs?](../glossary#position-ids)
  920. Example:
  921. ```python
  922. >>> import torch
  923. >>> from transformers import AutoFeatureExtractor, WhisperModel
  924. >>> from datasets import load_dataset
  925. >>> model = WhisperModel.from_pretrained("openai/whisper-base")
  926. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
  927. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  928. >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
  929. >>> input_features = inputs.input_features
  930. >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
  931. >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
  932. >>> list(last_hidden_state.shape)
  933. [1, 2, 512]
  934. ```"""
  935. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  936. output_hidden_states = (
  937. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  938. )
  939. use_cache = use_cache if use_cache is not None else self.config.use_cache
  940. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  941. if encoder_outputs is None:
  942. input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
  943. encoder_outputs = self.encoder(
  944. input_features,
  945. head_mask=head_mask,
  946. output_attentions=output_attentions,
  947. output_hidden_states=output_hidden_states,
  948. return_dict=return_dict,
  949. )
  950. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
  951. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  952. encoder_outputs = BaseModelOutput(
  953. last_hidden_state=encoder_outputs[0],
  954. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  955. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  956. )
  957. # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn)
  958. decoder_outputs = self.decoder(
  959. input_ids=decoder_input_ids,
  960. attention_mask=decoder_attention_mask,
  961. encoder_hidden_states=encoder_outputs[0],
  962. head_mask=decoder_head_mask,
  963. cross_attn_head_mask=cross_attn_head_mask,
  964. past_key_values=past_key_values,
  965. inputs_embeds=decoder_inputs_embeds,
  966. position_ids=decoder_position_ids,
  967. use_cache=use_cache,
  968. output_attentions=output_attentions,
  969. output_hidden_states=output_hidden_states,
  970. return_dict=return_dict,
  971. cache_position=cache_position,
  972. )
  973. if not return_dict:
  974. return decoder_outputs + encoder_outputs
  975. return Seq2SeqModelOutput(
  976. last_hidden_state=decoder_outputs.last_hidden_state,
  977. past_key_values=decoder_outputs.past_key_values,
  978. decoder_hidden_states=decoder_outputs.hidden_states,
  979. decoder_attentions=decoder_outputs.attentions,
  980. cross_attentions=decoder_outputs.cross_attentions,
  981. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  982. encoder_hidden_states=encoder_outputs.hidden_states,
  983. encoder_attentions=encoder_outputs.attentions,
  984. )
  985. @auto_docstring(
  986. custom_intro="""
  987. The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
  988. """
  989. )
  990. class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
  991. base_model_prefix = "model"
  992. _tied_weights_keys = ["proj_out.weight"]
  993. def __init__(self, config: WhisperConfig):
  994. super().__init__(config)
  995. self.model = WhisperModel(config)
  996. self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
  997. self.max_target_positions = config.max_target_positions
  998. # Initialize weights and apply final processing
  999. self.post_init()
  1000. def get_encoder(self):
  1001. return self.model.get_encoder()
  1002. def get_decoder(self):
  1003. return self.model.get_decoder()
  1004. def get_output_embeddings(self):
  1005. return self.proj_out
  1006. def set_output_embeddings(self, new_embeddings):
  1007. self.proj_out = new_embeddings
  1008. def get_input_embeddings(self) -> nn.Module:
  1009. return self.model.get_input_embeddings()
  1010. def freeze_encoder(self):
  1011. """
  1012. Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
  1013. not be updated during training.
  1014. """
  1015. self.model.encoder._freeze_parameters()
  1016. @auto_docstring
  1017. def forward(
  1018. self,
  1019. input_features: Optional[torch.FloatTensor] = None,
  1020. attention_mask: Optional[torch.LongTensor] = None,
  1021. decoder_input_ids: Optional[torch.LongTensor] = None,
  1022. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1023. head_mask: Optional[torch.Tensor] = None,
  1024. decoder_head_mask: Optional[torch.Tensor] = None,
  1025. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1026. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  1027. past_key_values: Optional[Cache] = None,
  1028. decoder_inputs_embeds: Optional[tuple[torch.FloatTensor]] = None,
  1029. decoder_position_ids: Optional[tuple[torch.LongTensor]] = None,
  1030. labels: Optional[torch.LongTensor] = None,
  1031. use_cache: Optional[bool] = None,
  1032. output_attentions: Optional[bool] = None,
  1033. output_hidden_states: Optional[bool] = None,
  1034. return_dict: Optional[bool] = None,
  1035. cache_position: Optional[torch.LongTensor] = None,
  1036. ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
  1037. r"""
  1038. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1039. Indices of decoder input sequence tokens in the vocabulary.
  1040. Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1041. [`PreTrainedTokenizer.__call__`] for details.
  1042. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1043. Whisper uses the `decoder_start_token_id` as the starting token for `decoder_input_ids` generation. If
  1044. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1045. `past_key_values`).
  1046. decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1047. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1048. be used by default.
  1049. If you want to change padding behavior, you should read
  1050. [`modeling_whisper._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the BART
  1051. paper](https://huggingface.co/papers/1910.13461) for more information on the default strategy.
  1052. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1053. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1054. - 1 indicates the head is **not masked**,
  1055. - 0 indicates the head is **masked**.
  1056. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1057. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1058. config.n_positions - 1]`.
  1059. [What are position IDs?](../glossary#position-ids)
  1060. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1061. Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
  1062. or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
  1063. only computed for the tokens with labels in `[0, ..., config.vocab_size]`. `sequence_length` should be smaller than or equal to `config.max_target_positions`.
  1064. Example:
  1065. ```python
  1066. >>> import torch
  1067. >>> from transformers import AutoProcessor, WhisperForConditionalGeneration
  1068. >>> from datasets import load_dataset
  1069. >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
  1070. >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
  1071. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  1072. >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
  1073. >>> input_features = inputs.input_features
  1074. >>> generated_ids = model.generate(inputs=input_features)
  1075. >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  1076. >>> transcription
  1077. ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
  1078. ```"""
  1079. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1080. if labels is not None:
  1081. if labels.shape[1] > self.max_target_positions:
  1082. raise ValueError(
  1083. f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
  1084. )
  1085. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1086. decoder_input_ids = shift_tokens_right(
  1087. labels, self.config.pad_token_id, self.config.decoder_start_token_id
  1088. )
  1089. outputs = self.model(
  1090. input_features,
  1091. attention_mask=attention_mask,
  1092. decoder_input_ids=decoder_input_ids,
  1093. encoder_outputs=encoder_outputs,
  1094. decoder_attention_mask=decoder_attention_mask,
  1095. head_mask=head_mask,
  1096. decoder_head_mask=decoder_head_mask,
  1097. cross_attn_head_mask=cross_attn_head_mask,
  1098. past_key_values=past_key_values,
  1099. decoder_inputs_embeds=decoder_inputs_embeds,
  1100. decoder_position_ids=decoder_position_ids,
  1101. use_cache=use_cache,
  1102. output_attentions=output_attentions,
  1103. output_hidden_states=output_hidden_states,
  1104. return_dict=return_dict,
  1105. cache_position=cache_position,
  1106. )
  1107. lm_logits = self.proj_out(outputs[0])
  1108. loss = None
  1109. if labels is not None:
  1110. loss_fct = CrossEntropyLoss()
  1111. # move labels to correct device to enable PP
  1112. labels = labels.to(lm_logits.device)
  1113. loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
  1114. if not return_dict:
  1115. output = (lm_logits,) + outputs[1:]
  1116. return ((loss,) + output) if loss is not None else output
  1117. return Seq2SeqLMOutput(
  1118. loss=loss,
  1119. logits=lm_logits,
  1120. past_key_values=outputs.past_key_values,
  1121. decoder_hidden_states=outputs.decoder_hidden_states,
  1122. decoder_attentions=outputs.decoder_attentions,
  1123. cross_attentions=outputs.cross_attentions,
  1124. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1125. encoder_hidden_states=outputs.encoder_hidden_states,
  1126. encoder_attentions=outputs.encoder_attentions,
  1127. )
  1128. class WhisperDecoderWrapper(WhisperPreTrainedModel):
  1129. """
  1130. This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
  1131. used in combination with the [`EncoderDecoderModel`] framework.
  1132. """
  1133. def __init__(self, config):
  1134. super().__init__(config)
  1135. config.is_encoder_decoder = False
  1136. self.decoder = WhisperDecoder(config)
  1137. def get_input_embeddings(self):
  1138. return self.decoder.embed_tokens
  1139. def set_input_embeddings(self, value):
  1140. self.decoder.embed_tokens = value
  1141. def forward(self, *args, **kwargs):
  1142. return self.decoder(*args, **kwargs)
  1143. @auto_docstring(
  1144. custom_intro="""
  1145. Whisper decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
  1146. """
  1147. )
  1148. class WhisperForCausalLM(WhisperPreTrainedModel, GenerationMixin):
  1149. _tied_weights_keys = ["proj_out.weight"]
  1150. main_input_name = "input_ids"
  1151. def __init__(self, config):
  1152. super().__init__(config)
  1153. config.is_encoder_decoder = False
  1154. self.model = WhisperDecoderWrapper(config)
  1155. self.proj_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1156. # Initialize weights and apply final processing
  1157. self.post_init()
  1158. def get_output_embeddings(self):
  1159. return self.proj_out
  1160. def set_output_embeddings(self, new_embeddings):
  1161. self.proj_out = new_embeddings
  1162. def get_input_embeddings(self) -> nn.Module:
  1163. return self.model.get_input_embeddings()
  1164. def set_input_embeddings(self, value):
  1165. self.model.set_input_embeddings(value)
  1166. def set_decoder(self, decoder):
  1167. self.model.decoder = decoder
  1168. def get_decoder(self):
  1169. return self.model.decoder
  1170. @auto_docstring
  1171. def forward(
  1172. self,
  1173. input_ids: Optional[torch.LongTensor] = None,
  1174. attention_mask: Optional[torch.Tensor] = None,
  1175. encoder_outputs: Optional[tuple[torch.FloatTensor]] = None,
  1176. head_mask: Optional[torch.Tensor] = None,
  1177. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1178. past_key_values: Optional[Cache] = None,
  1179. inputs_embeds: Optional[torch.FloatTensor] = None,
  1180. labels: Optional[torch.LongTensor] = None,
  1181. use_cache: Optional[bool] = None,
  1182. output_attentions: Optional[bool] = None,
  1183. output_hidden_states: Optional[bool] = None,
  1184. return_dict: Optional[bool] = None,
  1185. cache_position: Optional[torch.LongTensor] = None,
  1186. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  1187. r"""
  1188. encoder_outputs (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  1189. Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  1190. if the model is configured as a decoder.
  1191. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1192. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1193. - 1 indicates the head is **not masked**,
  1194. - 0 indicates the head is **masked**.
  1195. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1196. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1197. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1198. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1199. Example:
  1200. ```python
  1201. >>> from transformers import WhisperForCausalLM, WhisperForConditionalGeneration, WhisperProcessor
  1202. >>> import torch
  1203. >>> from datasets import load_dataset
  1204. >>> processor = WhisperProcessor.from_pretrained("openai/whisper-large-v2")
  1205. >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v2")
  1206. >>> assistant_model = WhisperForCausalLM.from_pretrained("distil-whisper/distil-large-v2")
  1207. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  1208. >>> sample = ds[0]["audio"]
  1209. >>> input_features = processor(
  1210. ... sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
  1211. ... ).input_features
  1212. >>> predicted_ids = model.generate(input_features, assistant_model=assistant_model)
  1213. >>> # decode token ids to text
  1214. >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
  1215. >>> transcription
  1216. ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.'
  1217. ```"""
  1218. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1219. output_hidden_states = (
  1220. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1221. )
  1222. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1223. # If the user passed a tuple or `BaseModelOutput` for encoder_outputs, we extract only the hidden states
  1224. if isinstance(encoder_outputs, (BaseModelOutput, tuple, list)):
  1225. encoder_outputs = encoder_outputs[0]
  1226. # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
  1227. outputs = self.model.decoder(
  1228. input_ids=input_ids,
  1229. attention_mask=attention_mask,
  1230. encoder_hidden_states=encoder_outputs,
  1231. head_mask=head_mask,
  1232. cross_attn_head_mask=cross_attn_head_mask,
  1233. past_key_values=past_key_values,
  1234. inputs_embeds=inputs_embeds,
  1235. use_cache=use_cache,
  1236. output_attentions=output_attentions,
  1237. output_hidden_states=output_hidden_states,
  1238. return_dict=return_dict,
  1239. cache_position=cache_position,
  1240. )
  1241. logits = self.proj_out(outputs[0])
  1242. loss = None
  1243. if labels is not None:
  1244. labels = labels.to(logits.device)
  1245. loss_fct = CrossEntropyLoss()
  1246. loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
  1247. if not return_dict:
  1248. output = (logits,) + outputs[1:]
  1249. return (loss,) + output if loss is not None else output
  1250. return CausalLMOutputWithCrossAttentions(
  1251. loss=loss,
  1252. logits=logits,
  1253. past_key_values=outputs.past_key_values,
  1254. hidden_states=outputs.hidden_states,
  1255. attentions=outputs.attentions,
  1256. cross_attentions=outputs.cross_attentions,
  1257. )
  1258. @auto_docstring(
  1259. custom_intro="""
  1260. Whisper Encoder Model with a sequence classification head on top (a linear layer over the pooled output) for tasks
  1261. like SUPERB Keyword Spotting.
  1262. """
  1263. )
  1264. class WhisperForAudioClassification(WhisperPreTrainedModel):
  1265. def __init__(self, config):
  1266. super().__init__(config)
  1267. self.encoder = WhisperEncoder(config)
  1268. num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
  1269. if config.use_weighted_layer_sum:
  1270. self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
  1271. self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
  1272. self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
  1273. # Initialize weights and apply final processing
  1274. self.post_init()
  1275. def freeze_encoder(self):
  1276. """
  1277. Calling this function will disable the gradient computation for the Whisper encoder so that its parameters will
  1278. not be updated during training. Only the projection layers and classification head will be updated.
  1279. """
  1280. self.encoder._freeze_parameters()
  1281. def get_input_embeddings(self) -> nn.Module:
  1282. return self.encoder.get_input_embeddings()
  1283. def set_input_embeddings(self, value: nn.Module):
  1284. self.encoder.set_input_embeddings(value)
  1285. @auto_docstring
  1286. def forward(
  1287. self,
  1288. input_features: Optional[torch.LongTensor] = None,
  1289. head_mask: Optional[torch.Tensor] = None,
  1290. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  1291. labels: Optional[torch.LongTensor] = None,
  1292. output_attentions: Optional[bool] = None,
  1293. output_hidden_states: Optional[bool] = None,
  1294. return_dict: Optional[bool] = None,
  1295. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  1296. r"""
  1297. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1298. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1299. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1300. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1301. Example:
  1302. ```python
  1303. >>> import torch
  1304. >>> from transformers import AutoFeatureExtractor, WhisperForAudioClassification
  1305. >>> from datasets import load_dataset
  1306. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
  1307. >>> model = WhisperForAudioClassification.from_pretrained("sanchit-gandhi/whisper-medium-fleurs-lang-id")
  1308. >>> ds = load_dataset("google/fleurs", "all", split="validation", streaming=True)
  1309. >>> sample = next(iter(ds))
  1310. >>> inputs = feature_extractor(
  1311. ... sample["audio"]["array"], sampling_rate=sample["audio"]["sampling_rate"], return_tensors="pt"
  1312. ... )
  1313. >>> input_features = inputs.input_features
  1314. >>> with torch.no_grad():
  1315. ... logits = model(input_features).logits
  1316. >>> predicted_class_ids = torch.argmax(logits).item()
  1317. >>> predicted_label = model.config.id2label[predicted_class_ids]
  1318. >>> predicted_label
  1319. 'Afrikaans'
  1320. ```"""
  1321. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1322. output_hidden_states = (
  1323. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1324. )
  1325. if self.config.use_weighted_layer_sum:
  1326. output_hidden_states = True
  1327. elif output_hidden_states is None:
  1328. output_hidden_states = self.config.output_hidden_states
  1329. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1330. if encoder_outputs is None:
  1331. encoder_outputs = self.encoder(
  1332. input_features,
  1333. head_mask=head_mask,
  1334. output_attentions=output_attentions,
  1335. output_hidden_states=output_hidden_states,
  1336. return_dict=return_dict,
  1337. )
  1338. if self.config.use_weighted_layer_sum:
  1339. hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
  1340. hidden_states = torch.stack(hidden_states, dim=1)
  1341. norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
  1342. hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
  1343. else:
  1344. hidden_states = encoder_outputs[0]
  1345. hidden_states = self.projector(hidden_states)
  1346. pooled_output = hidden_states.mean(dim=1)
  1347. logits = self.classifier(pooled_output)
  1348. loss = None
  1349. if labels is not None:
  1350. loss_fct = CrossEntropyLoss()
  1351. # move labels to correct device to enable PP
  1352. labels = labels.to(logits.device)
  1353. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1354. if not return_dict:
  1355. output = (logits,) + encoder_outputs[1:]
  1356. return ((loss,) + output) if loss is not None else output
  1357. return SequenceClassifierOutput(
  1358. loss=loss,
  1359. logits=logits,
  1360. hidden_states=encoder_outputs.hidden_states,
  1361. attentions=encoder_outputs.attentions,
  1362. )
  1363. __all__ = [
  1364. "WhisperForCausalLM",
  1365. "WhisperForConditionalGeneration",
  1366. "WhisperModel",
  1367. "WhisperPreTrainedModel",
  1368. "WhisperForAudioClassification",
  1369. ]