modular_parakeet.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Parakeet model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, CausalLMOutput
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
  27. from ...utils.generic import check_model_inputs
  28. from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule
  29. from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
  30. from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
  31. class ParakeetEncoderRelPositionalEncoding(nn.Module):
  32. """Relative positional encoding for Parakeet."""
  33. inv_freq: torch.Tensor # fix linting for `register_buffer`
  34. def __init__(self, config: ParakeetEncoderConfig, device=None):
  35. super().__init__()
  36. self.max_position_embeddings = config.max_position_embeddings
  37. base = 10000.0
  38. inv_freq = 1.0 / (
  39. base
  40. ** (
  41. torch.arange(0, config.hidden_size, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
  42. / config.hidden_size
  43. )
  44. )
  45. self.register_buffer("inv_freq", inv_freq, persistent=False)
  46. @torch.no_grad()
  47. def forward(self, hidden_states: torch.Tensor):
  48. seq_length = hidden_states.shape[1]
  49. if seq_length > self.max_position_embeddings:
  50. raise ValueError(
  51. f"Sequence Length: {seq_length} has to be less or equal than "
  52. f"config.max_position_embeddings {self.max_position_embeddings}."
  53. )
  54. position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
  55. inv_freq_expanded = (
  56. self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
  57. )
  58. position_ids_expanded = position_ids[None, None, :].float()
  59. device_type = (
  60. hidden_states.device.type
  61. if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
  62. else "cpu"
  63. )
  64. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  65. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  66. sin = freqs.sin()
  67. cos = freqs.cos()
  68. # interleave sin and cos
  69. pos_embed = torch.stack([sin, cos], dim=-1)
  70. pos_embed = pos_embed.reshape(*pos_embed.shape[:-2], -1)
  71. return pos_embed.to(dtype=hidden_states.dtype)
  72. class ParakeetEncoderFeedForward(nn.Module):
  73. def __init__(self, config: ParakeetEncoderConfig):
  74. super().__init__()
  75. self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
  76. self.activation = ACT2FN[config.hidden_act]
  77. self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
  78. self.activation_dropout = config.activation_dropout
  79. def forward(self, hidden_states):
  80. hidden_states = self.activation(self.linear1(hidden_states))
  81. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  82. hidden_states = self.linear2(hidden_states)
  83. return hidden_states
  84. class ParakeetEncoderConvolutionModule(FastSpeech2ConformerConvolutionModule):
  85. def __init__(self, config: ParakeetEncoderConfig, module_config=None):
  86. super().__init__(config, module_config)
  87. class ParakeetEncoderAttention(LlamaAttention):
  88. """Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860."""
  89. def __init__(self, config: ParakeetEncoderConfig, layer_idx: int):
  90. super().__init__(config, layer_idx=layer_idx)
  91. self.is_causal = False
  92. # W_{k,R} projection
  93. self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
  94. # global content bias
  95. self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
  96. # global positional bias
  97. self.bias_v = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
  98. def forward(
  99. self,
  100. hidden_states: torch.Tensor,
  101. position_embeddings: Optional[torch.Tensor],
  102. attention_mask: Optional[torch.Tensor] = None,
  103. **kwargs: Unpack[TransformersKwargs],
  104. ) -> tuple[torch.Tensor, torch.Tensor]:
  105. input_shape = hidden_states.shape[:-1]
  106. batch_size, seq_length = input_shape
  107. hidden_shape = (batch_size, seq_length, -1, self.head_dim)
  108. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  109. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  110. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  111. attention_interface: Callable = eager_attention_forward
  112. if self.config._attn_implementation != "eager":
  113. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  114. query_states_with_bias_u = query_states + self.bias_u.view(
  115. 1, self.config.num_attention_heads, 1, self.head_dim
  116. )
  117. query_states_with_bias_v = query_states + self.bias_v.view(
  118. 1, self.config.num_attention_heads, 1, self.head_dim
  119. )
  120. relative_key_states = self.relative_k_proj(position_embeddings)
  121. relative_key_states = relative_key_states.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
  122. # terms (b) and (d)
  123. matrix_bd = query_states_with_bias_v @ relative_key_states.permute(0, 2, 3, 1)
  124. matrix_bd = self._rel_shift(matrix_bd)
  125. matrix_bd = matrix_bd[..., :seq_length]
  126. matrix_bd = matrix_bd * self.scaling
  127. if attention_mask is not None:
  128. # here the original codebase uses -10000.0 rather than float("-inf") and then manual masked fill with 0.0s
  129. # see: https://github.com/NVIDIA-NeMo/NeMo/blob/8cfedd7203462cb251a914e700e5605444277561/nemo/collections/asr/parts/submodules/multi_head_attention.py#L320-L340
  130. # we rather went for a straight-forward approach with float("-inf")
  131. matrix_bd = matrix_bd.masked_fill_(attention_mask.logical_not(), float("-inf"))
  132. # will compute matrix_ac - terms (a) and (c) - and add matrix_bd
  133. attn_output, attn_weights = attention_interface(
  134. self,
  135. query=query_states_with_bias_u,
  136. key=key_states,
  137. value=value_states,
  138. attention_mask=matrix_bd,
  139. dropout=0.0 if not self.training else self.attention_dropout,
  140. scaling=self.scaling,
  141. **kwargs,
  142. )
  143. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  144. attn_output = self.o_proj(attn_output)
  145. return attn_output, attn_weights
  146. def _rel_shift(self, attention_scores):
  147. """Relative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860."""
  148. batch_size, num_heads, query_length, position_length = attention_scores.shape
  149. attention_scores = nn.functional.pad(attention_scores, pad=(1, 0))
  150. attention_scores = attention_scores.view(batch_size, num_heads, -1, query_length)
  151. attention_scores = attention_scores[:, :, 1:].view(batch_size, num_heads, query_length, position_length)
  152. return attention_scores
  153. class ParakeetEncoderSubsamplingConv2D(nn.Module):
  154. def __init__(self, config: ParakeetEncoderConfig):
  155. super().__init__()
  156. self.kernel_size = config.subsampling_conv_kernel_size
  157. self.stride = config.subsampling_conv_stride
  158. self.channels = config.subsampling_conv_channels
  159. self.padding = (self.kernel_size - 1) // 2
  160. self.num_layers = int(math.log2(config.subsampling_factor))
  161. # define layers
  162. self.layers = nn.ModuleList()
  163. self.layers.append(
  164. nn.Conv2d(1, self.channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  165. )
  166. self.layers.append(nn.ReLU())
  167. for i in range(self.num_layers - 1):
  168. # depthwise conv
  169. self.layers.append(
  170. nn.Conv2d(
  171. self.channels,
  172. self.channels,
  173. kernel_size=self.kernel_size,
  174. stride=self.stride,
  175. padding=self.padding,
  176. groups=self.channels,
  177. )
  178. )
  179. # pointwise conv
  180. self.layers.append(nn.Conv2d(self.channels, self.channels, kernel_size=1))
  181. # activation
  182. self.layers.append(nn.ReLU())
  183. out_length = config.num_mel_bins // (self.stride**self.num_layers)
  184. self.linear = nn.Linear(config.subsampling_conv_channels * out_length, config.hidden_size, bias=True)
  185. def _get_output_length(self, input_lengths: torch.Tensor, conv_layer: nn.Conv2d):
  186. if hasattr(conv_layer, "stride") and conv_layer.stride != (1, 1):
  187. padding = conv_layer.padding
  188. kernel_size = conv_layer.kernel_size[0]
  189. stride = conv_layer.stride[0]
  190. output_lengths = (input_lengths + padding[0] + padding[1] - kernel_size) // stride + 1
  191. return output_lengths
  192. return input_lengths
  193. def forward(self, input_features: torch.Tensor, attention_mask: torch.Tensor = None):
  194. hidden_states = input_features.unsqueeze(1)
  195. current_lengths = attention_mask.sum(-1) if attention_mask is not None else None
  196. for layer in self.layers:
  197. hidden_states = layer(hidden_states)
  198. # mask the hidden states
  199. if isinstance(layer, nn.Conv2d) and attention_mask is not None:
  200. current_lengths = self._get_output_length(current_lengths, layer)
  201. current_seq_length = hidden_states.shape[2]
  202. channel_mask = (
  203. torch.arange(current_seq_length, device=attention_mask.device) < current_lengths[:, None]
  204. )
  205. hidden_states *= channel_mask[:, None, :, None]
  206. hidden_states = hidden_states.transpose(1, 2).reshape(hidden_states.shape[0], hidden_states.shape[2], -1)
  207. hidden_states = self.linear(hidden_states)
  208. return hidden_states
  209. class ParakeetEncoderBlock(GradientCheckpointingLayer):
  210. def __init__(self, config: ParakeetEncoderConfig, layer_idx: Optional[int] = None):
  211. super().__init__()
  212. self.gradient_checkpointing = False
  213. self.feed_forward1 = ParakeetEncoderFeedForward(config)
  214. self.self_attn = ParakeetEncoderAttention(config, layer_idx)
  215. self.conv = ParakeetEncoderConvolutionModule(config)
  216. self.feed_forward2 = ParakeetEncoderFeedForward(config)
  217. self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size)
  218. self.norm_self_att = nn.LayerNorm(config.hidden_size)
  219. self.norm_conv = nn.LayerNorm(config.hidden_size)
  220. self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size)
  221. self.norm_out = nn.LayerNorm(config.hidden_size)
  222. def forward(
  223. self,
  224. hidden_states: torch.Tensor,
  225. attention_mask: Optional[torch.Tensor] = None,
  226. position_embeddings: Optional[torch.Tensor] = None,
  227. **kwargs: Unpack[TransformersKwargs],
  228. ) -> torch.Tensor:
  229. residual = hidden_states
  230. hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
  231. hidden_states = residual + 0.5 * hidden_states # the conformer architecture uses a factor of 0.5
  232. normalized_hidden_states = self.norm_self_att(hidden_states)
  233. attn_output, _ = self.self_attn(
  234. hidden_states=normalized_hidden_states,
  235. attention_mask=attention_mask,
  236. position_embeddings=position_embeddings,
  237. **kwargs,
  238. )
  239. hidden_states = hidden_states + attn_output
  240. conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
  241. hidden_states = hidden_states + conv_output
  242. ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
  243. hidden_states = hidden_states + 0.5 * ff2_output # the conformer architecture uses a factor of 0.5
  244. hidden_states = self.norm_out(hidden_states)
  245. return hidden_states
  246. @auto_docstring
  247. class ParakeetPreTrainedModel(PreTrainedModel):
  248. config: ParakeetCTCConfig
  249. base_model_prefix = "model"
  250. main_input_name = "input_features"
  251. supports_gradient_checkpointing = True
  252. _no_split_modules = ["ParakeetEncoderBlock"]
  253. _supports_flat_attention_mask = True
  254. _supports_sdpa = True
  255. _supports_flex_attn = True
  256. # TODO: @eustlb, add support when flash attention supports custom attention bias
  257. _supports_flash_attn = False
  258. _can_compile_fullgraph = True
  259. _supports_attention_backend = True
  260. _can_record_outputs = {
  261. "hidden_states": ParakeetEncoderBlock,
  262. "attentions": ParakeetEncoderAttention,
  263. }
  264. def _init_weights(self, module):
  265. super()._init_weights(module)
  266. if hasattr(self.config, "initializer_range"):
  267. std = self.config.initializer_range
  268. else:
  269. # 0.02 is the standard default value accross the library
  270. std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
  271. if isinstance(module, ParakeetEncoderAttention):
  272. # Initialize positional bias parameters
  273. module.bias_u.data.normal_(mean=0.0, std=std)
  274. module.bias_v.data.normal_(mean=0.0, std=std)
  275. def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
  276. encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config
  277. kernel_size = encoder_config.subsampling_conv_kernel_size
  278. stride = encoder_config.subsampling_conv_stride
  279. num_layers = int(math.log2(encoder_config.subsampling_factor))
  280. all_paddings = (kernel_size - 1) // 2 * 2
  281. add_pad = all_paddings - kernel_size
  282. lengths = input_lengths
  283. for _ in range(num_layers):
  284. lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
  285. lengths = torch.floor(lengths)
  286. return lengths.to(dtype=torch.int)
  287. def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: Optional[int] = None):
  288. """
  289. Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
  290. when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
  291. """
  292. output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  293. # Use target_length if provided, otherwise use max length in batch
  294. max_length = target_length if target_length is not None else output_lengths.max()
  295. attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
  296. return attention_mask
  297. @auto_docstring(
  298. custom_intro="""
  299. The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
  300. """
  301. )
  302. class ParakeetEncoder(ParakeetPreTrainedModel):
  303. config: ParakeetEncoderConfig
  304. base_model_prefix = "encoder"
  305. def __init__(self, config: ParakeetEncoderConfig):
  306. super().__init__(config)
  307. self.config = config
  308. self.gradient_checkpointing = False
  309. self.dropout = config.dropout
  310. self.dropout_positions = config.dropout_positions
  311. self.layerdrop = config.layerdrop
  312. self.input_scale = math.sqrt(config.hidden_size) if config.scale_input else 1.0
  313. self.subsampling = ParakeetEncoderSubsamplingConv2D(config)
  314. self.encode_positions = ParakeetEncoderRelPositionalEncoding(config)
  315. self.layers = nn.ModuleList(
  316. [ParakeetEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  317. )
  318. self.post_init()
  319. @auto_docstring
  320. @check_model_inputs()
  321. @can_return_tuple
  322. def forward(
  323. self,
  324. input_features: torch.Tensor,
  325. attention_mask: Optional[torch.Tensor] = None,
  326. **kwargs: Unpack[TransformersKwargs],
  327. ) -> BaseModelOutput:
  328. r"""
  329. Example:
  330. ```python
  331. >>> from transformers import AutoProcessor, ParakeetEncoder
  332. >>> from datasets import load_dataset, Audio
  333. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  334. >>> processor = AutoProcessor.from_pretrained(model_id)
  335. >>> encoder = ParakeetEncoder.from_pretrained(model_id)
  336. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  337. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  338. >>> inputs = processor(ds[0]["audio"]["array"])
  339. >>> encoder_outputs = encoder(**inputs)
  340. >>> print(encoder_outputs.last_hidden_state.shape)
  341. ```
  342. """
  343. hidden_states = self.subsampling(input_features, attention_mask)
  344. hidden_states = hidden_states * self.input_scale
  345. position_embeddings = self.encode_positions(hidden_states)
  346. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  347. position_embeddings = nn.functional.dropout(
  348. position_embeddings, p=self.dropout_positions, training=self.training
  349. )
  350. if attention_mask is not None:
  351. attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
  352. attention_mask = attention_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
  353. attention_mask = attention_mask & attention_mask.transpose(1, 2)
  354. attention_mask = attention_mask.unsqueeze(1)
  355. for encoder_layer in self.layers:
  356. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  357. to_drop = False
  358. if self.training:
  359. dropout_probability = torch.rand([])
  360. if dropout_probability < self.layerdrop: # skip the layer
  361. to_drop = True
  362. if not to_drop:
  363. hidden_states = encoder_layer(
  364. hidden_states,
  365. attention_mask=attention_mask,
  366. position_embeddings=position_embeddings,
  367. **kwargs,
  368. )
  369. return BaseModelOutput(last_hidden_state=hidden_states)
  370. @dataclass
  371. class ParakeetGenerateOutput(ModelOutput):
  372. """
  373. Outputs of Parakeet models.
  374. Args:
  375. sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  376. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  377. if all batches finished early due to the `eos_token_id`.
  378. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  379. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  380. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  381. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  382. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  383. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  384. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  385. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  386. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  387. `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
  388. """
  389. sequences: torch.LongTensor
  390. logits: Optional[tuple[torch.FloatTensor]] = None
  391. attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
  392. hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
  393. @auto_docstring(
  394. custom_intro="""
  395. Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
  396. """
  397. )
  398. class ParakeetForCTC(ParakeetPreTrainedModel):
  399. config: ParakeetCTCConfig
  400. def __init__(self, config: ParakeetCTCConfig):
  401. super().__init__(config)
  402. self.encoder = ParakeetEncoder(config.encoder_config)
  403. # Conv rather than linear to be consistent with NeMO decoding layer
  404. self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
  405. self.post_init()
  406. @auto_docstring
  407. @can_return_tuple
  408. def forward(
  409. self,
  410. input_features: torch.Tensor,
  411. attention_mask: Optional[torch.Tensor] = None,
  412. labels: Optional[torch.Tensor] = None,
  413. **kwargs: Unpack[TransformersKwargs],
  414. ) -> CausalLMOutput:
  415. r"""
  416. Example:
  417. ```python
  418. >>> from transformers import AutoProcessor, ParakeetForCTC
  419. >>> from datasets import load_dataset, Audio
  420. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  421. >>> processor = AutoProcessor.from_pretrained(model_id)
  422. >>> model = ParakeetForCTC.from_pretrained(model_id)
  423. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  424. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  425. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  426. >>> outputs = model(**inputs)
  427. >>> print(outputs.loss)
  428. ```"""
  429. encoder_outputs = self.encoder(
  430. input_features=input_features,
  431. attention_mask=attention_mask,
  432. **kwargs,
  433. )
  434. hidden_states = encoder_outputs.last_hidden_state
  435. logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
  436. loss = None
  437. if labels is not None:
  438. # retrieve loss input_lengths from attention_mask
  439. attention_mask = (
  440. attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
  441. )
  442. input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  443. # assuming that padded tokens are filled with -100
  444. # when not being attended to
  445. labels_mask = labels != self.config.pad_token_id
  446. target_lengths = labels_mask.sum(-1)
  447. flattened_targets = labels.masked_select(labels_mask)
  448. # ctc_loss doesn't support fp16
  449. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  450. with torch.backends.cudnn.flags(enabled=False):
  451. loss = nn.functional.ctc_loss(
  452. log_probs,
  453. flattened_targets,
  454. input_lengths,
  455. target_lengths,
  456. blank=self.config.pad_token_id,
  457. reduction=self.config.ctc_loss_reduction,
  458. zero_infinity=self.config.ctc_zero_infinity,
  459. )
  460. return CausalLMOutput(
  461. loss=loss,
  462. logits=logits,
  463. hidden_states=encoder_outputs.hidden_states,
  464. attentions=encoder_outputs.attentions,
  465. )
  466. @torch.no_grad()
  467. def generate(
  468. self,
  469. input_features: torch.Tensor,
  470. attention_mask: Optional[torch.Tensor] = None,
  471. return_dict_in_generate: bool = False,
  472. **kwargs: Unpack[TransformersKwargs],
  473. ) -> Union[ParakeetGenerateOutput, torch.LongTensor]:
  474. r"""
  475. Example:
  476. ```python
  477. >>> from transformers import AutoProcessor, ParakeetForCTC
  478. >>> from datasets import load_dataset, Audio
  479. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  480. >>> processor = AutoProcessor.from_pretrained(model_id)
  481. >>> model = ParakeetForCTC.from_pretrained(model_id)
  482. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  483. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  484. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  485. >>> predicted_ids = model.generate(**inputs)
  486. >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
  487. >>> print(transcription)
  488. ```
  489. """
  490. kwargs["return_dict"] = True
  491. outputs: CausalLMOutput = self.forward(
  492. input_features=input_features,
  493. attention_mask=attention_mask,
  494. **kwargs,
  495. )
  496. # greedy decoding
  497. sequences = outputs.logits.argmax(dim=-1)
  498. # mask out padded tokens
  499. if attention_mask is not None:
  500. attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
  501. sequences[~attention_mask] = self.config.pad_token_id
  502. if return_dict_in_generate:
  503. return ParakeetGenerateOutput(
  504. sequences=sequences,
  505. logits=outputs.logits,
  506. attentions=outputs.attentions,
  507. hidden_states=outputs.hidden_states,
  508. )
  509. return sequences
  510. __all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]