modeling_parakeet.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/parakeet/modular_parakeet.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_parakeet.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import math
  22. from dataclasses import dataclass
  23. from typing import Callable, Optional, Union
  24. import torch
  25. from torch import nn
  26. from ...activations import ACT2FN
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import BaseModelOutput, CausalLMOutput
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
  32. from ...utils.deprecation import deprecate_kwarg
  33. from ...utils.generic import check_model_inputs
  34. from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
  35. class ParakeetEncoderRelPositionalEncoding(nn.Module):
  36. """Relative positional encoding for Parakeet."""
  37. inv_freq: torch.Tensor # fix linting for `register_buffer`
  38. def __init__(self, config: ParakeetEncoderConfig, device=None):
  39. super().__init__()
  40. self.max_position_embeddings = config.max_position_embeddings
  41. base = 10000.0
  42. inv_freq = 1.0 / (
  43. base
  44. ** (
  45. torch.arange(0, config.hidden_size, 2, dtype=torch.int64).to(device=device, dtype=torch.float)
  46. / config.hidden_size
  47. )
  48. )
  49. self.register_buffer("inv_freq", inv_freq, persistent=False)
  50. @torch.no_grad()
  51. def forward(self, hidden_states: torch.Tensor):
  52. seq_length = hidden_states.shape[1]
  53. if seq_length > self.max_position_embeddings:
  54. raise ValueError(
  55. f"Sequence Length: {seq_length} has to be less or equal than "
  56. f"config.max_position_embeddings {self.max_position_embeddings}."
  57. )
  58. position_ids = torch.arange(seq_length - 1, -seq_length, -1, device=hidden_states.device)
  59. inv_freq_expanded = (
  60. self.inv_freq[None, :, None].float().expand(hidden_states.shape[0], -1, 1).to(hidden_states.device)
  61. )
  62. position_ids_expanded = position_ids[None, None, :].float()
  63. device_type = (
  64. hidden_states.device.type
  65. if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
  66. else "cpu"
  67. )
  68. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  69. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  70. sin = freqs.sin()
  71. cos = freqs.cos()
  72. # interleave sin and cos
  73. pos_embed = torch.stack([sin, cos], dim=-1)
  74. pos_embed = pos_embed.reshape(*pos_embed.shape[:-2], -1)
  75. return pos_embed.to(dtype=hidden_states.dtype)
  76. class ParakeetEncoderFeedForward(nn.Module):
  77. def __init__(self, config: ParakeetEncoderConfig):
  78. super().__init__()
  79. self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.attention_bias)
  80. self.activation = ACT2FN[config.hidden_act]
  81. self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.attention_bias)
  82. self.activation_dropout = config.activation_dropout
  83. def forward(self, hidden_states):
  84. hidden_states = self.activation(self.linear1(hidden_states))
  85. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  86. hidden_states = self.linear2(hidden_states)
  87. return hidden_states
  88. class ParakeetEncoderConvolutionModule(nn.Module):
  89. def __init__(self, config: ParakeetEncoderConfig, module_config=None):
  90. """
  91. Args:
  92. config (ParakeetEncoderConfig): Configuration for the model.
  93. module_config (dict): Configuration for the module (e.g., encoder or decoder).
  94. """
  95. super().__init__()
  96. channels = config.hidden_size
  97. # kernel_size should be an odd number for 'SAME' padding
  98. if module_config is None:
  99. # e.g. using `ParakeetEncoderEncoderConfig` in src/transformers/models/parakeet_encoder/configuration_parakeet_encoder.py
  100. kernel_size = config.conv_kernel_size
  101. self.activation = ACT2FN[getattr(config, "hidden_act", "silu")]
  102. else:
  103. kernel_size = module_config["kernel_size"]
  104. self.activation = ACT2FN[module_config.get("activation", "silu")]
  105. self.padding = (kernel_size - 1) // 2
  106. self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=True)
  107. self.depthwise_conv = nn.Conv1d(
  108. channels, channels, kernel_size, stride=1, padding=self.padding, groups=channels, bias=True
  109. )
  110. self.norm = nn.BatchNorm1d(channels)
  111. self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=True)
  112. def forward(self, hidden_states, attention_mask=None):
  113. """
  114. Compute convolution module.
  115. Args:
  116. hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
  117. attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
  118. Returns:
  119. `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
  120. """
  121. # exchange the temporal dimension and the feature dimension
  122. hidden_states = hidden_states.transpose(1, 2)
  123. # GLU mechanism, (batch_size, 2*channel, dim)
  124. hidden_states = self.pointwise_conv1(hidden_states)
  125. # (batch_size, channel, dim)
  126. hidden_states = nn.functional.glu(hidden_states, dim=1)
  127. # Apply padding mask before convolution
  128. if attention_mask is not None:
  129. all_masked_rows = torch.all(~attention_mask, dim=-1)
  130. hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
  131. # 1D Depthwise Conv
  132. hidden_states = self.depthwise_conv(hidden_states)
  133. hidden_states = self.norm(hidden_states)
  134. hidden_states = self.activation(hidden_states)
  135. hidden_states = self.pointwise_conv2(hidden_states)
  136. return hidden_states.transpose(1, 2)
  137. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  138. """
  139. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  140. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  141. """
  142. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  143. if n_rep == 1:
  144. return hidden_states
  145. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  146. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  147. def eager_attention_forward(
  148. module: nn.Module,
  149. query: torch.Tensor,
  150. key: torch.Tensor,
  151. value: torch.Tensor,
  152. attention_mask: Optional[torch.Tensor],
  153. scaling: float,
  154. dropout: float = 0.0,
  155. **kwargs: Unpack[TransformersKwargs],
  156. ):
  157. key_states = repeat_kv(key, module.num_key_value_groups)
  158. value_states = repeat_kv(value, module.num_key_value_groups)
  159. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  160. if attention_mask is not None:
  161. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  162. attn_weights = attn_weights + causal_mask
  163. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  164. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  165. attn_output = torch.matmul(attn_weights, value_states)
  166. attn_output = attn_output.transpose(1, 2).contiguous()
  167. return attn_output, attn_weights
class ParakeetEncoderAttention(nn.Module):
    """Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860.

    The content terms (a) and (c) are computed by the selected attention backend from the query (plus `bias_u`)
    and the keys. The positional terms (b) and (d) are computed here as `matrix_bd` and handed to the backend as
    its additive `attention_mask`.
    """

    def __init__(self, config: ParakeetEncoderConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Falls back to hidden_size // num_attention_heads when the config has no explicit `head_dim`.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Number of query heads sharing one key/value head (read by `eager_attention_forward`).
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = False

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )
        # W_{k,R} projection
        self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
        # global content bias
        self.bias_u = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))
        # global positional bias
        self.bias_v = nn.Parameter(torch.zeros(config.num_attention_heads, self.head_dim))

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply relative-position multi-head self-attention.

        Args:
            hidden_states: Tensor whose leading two dims are `(batch, seq_length)` and whose last dim is fed to the
                q/k/v projections.
            position_embeddings: Relative position embeddings, projected by `relative_k_proj` and viewed as
                `(batch, num_positions, num_attention_heads, head_dim)`. Only the first `seq_length` columns of
                the shifted score matrix are kept.
            attention_mask: Optional boolean mask, `True` where attention is allowed; must broadcast against
                `(batch, num_attention_heads, seq_length, seq_length)`. `False` entries become `-inf` in the
                positional bias.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)` as returned by the backend, with `attn_output` reshaped to
            `(batch, seq_length, -1)` and passed through `o_proj`.
        """
        input_shape = hidden_states.shape[:-1]
        batch_size, seq_length = input_shape
        # -1 resolves to the number of heads of each projection (query heads or key/value heads).
        hidden_shape = (batch_size, seq_length, -1, self.head_dim)

        # (batch, heads, seq_length, head_dim)
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        # Pick the attention backend configured on the model; eager is the in-file fallback.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        # Per-head biases broadcast over batch and time.
        query_states_with_bias_u = query_states + self.bias_u.view(
            1, self.config.num_attention_heads, 1, self.head_dim
        )
        query_states_with_bias_v = query_states + self.bias_v.view(
            1, self.config.num_attention_heads, 1, self.head_dim
        )

        relative_key_states = self.relative_k_proj(position_embeddings)
        relative_key_states = relative_key_states.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)

        # terms (b) and (d)
        # (batch, heads, seq_length, head_dim) @ (batch, heads, head_dim, num_positions)
        matrix_bd = query_states_with_bias_v @ relative_key_states.permute(0, 2, 3, 1)
        matrix_bd = self._rel_shift(matrix_bd)
        # Keep only the first `seq_length` columns so the bias lines up with the key axis.
        matrix_bd = matrix_bd[..., :seq_length]
        # The backend scales only its own query-key product, so the positional term is scaled here.
        matrix_bd = matrix_bd * self.scaling

        if attention_mask is not None:
            # here the original codebase uses -10000.0 rather than float("-inf") and then manual masked fill with 0.0s
            # see: https://github.com/NVIDIA-NeMo/NeMo/blob/8cfedd7203462cb251a914e700e5605444277561/nemo/collections/asr/parts/submodules/multi_head_attention.py#L320-L340
            # we rather went for a straight-forward approach with float("-inf")
            matrix_bd = matrix_bd.masked_fill_(attention_mask.logical_not(), float("-inf"))

        # will compute matrix_ac - terms (a) and (c) - and add matrix_bd
        attn_output, attn_weights = attention_interface(
            self,
            query=query_states_with_bias_u,
            key=key_states,
            value=value_states,
            attention_mask=matrix_bd,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge the heads back: (batch, seq_length, heads * head_dim)
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

    def _rel_shift(self, attention_scores):
        """Relative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860.

        Takes scores of shape `(batch, num_heads, query_length, position_length)` and returns a tensor of the
        same shape. The pad / view / slice / view sequence below re-reads the same memory with a different row
        length, which offsets each query row relative to the previous one.
        """
        batch_size, num_heads, query_length, position_length = attention_scores.shape
        # Prepend one zero column: (..., query_length, position_length + 1)
        attention_scores = nn.functional.pad(attention_scores, pad=(1, 0))
        # Reinterpret the padded buffer with rows of length `query_length`
        attention_scores = attention_scores.view(batch_size, num_heads, -1, query_length)
        # Drop the first row and restore the original shape
        attention_scores = attention_scores[:, :, 1:].view(batch_size, num_heads, query_length, position_length)
        return attention_scores
  253. class ParakeetEncoderSubsamplingConv2D(nn.Module):
  254. def __init__(self, config: ParakeetEncoderConfig):
  255. super().__init__()
  256. self.kernel_size = config.subsampling_conv_kernel_size
  257. self.stride = config.subsampling_conv_stride
  258. self.channels = config.subsampling_conv_channels
  259. self.padding = (self.kernel_size - 1) // 2
  260. self.num_layers = int(math.log2(config.subsampling_factor))
  261. # define layers
  262. self.layers = nn.ModuleList()
  263. self.layers.append(
  264. nn.Conv2d(1, self.channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding)
  265. )
  266. self.layers.append(nn.ReLU())
  267. for i in range(self.num_layers - 1):
  268. # depthwise conv
  269. self.layers.append(
  270. nn.Conv2d(
  271. self.channels,
  272. self.channels,
  273. kernel_size=self.kernel_size,
  274. stride=self.stride,
  275. padding=self.padding,
  276. groups=self.channels,
  277. )
  278. )
  279. # pointwise conv
  280. self.layers.append(nn.Conv2d(self.channels, self.channels, kernel_size=1))
  281. # activation
  282. self.layers.append(nn.ReLU())
  283. out_length = config.num_mel_bins // (self.stride**self.num_layers)
  284. self.linear = nn.Linear(config.subsampling_conv_channels * out_length, config.hidden_size, bias=True)
  285. def _get_output_length(self, input_lengths: torch.Tensor, conv_layer: nn.Conv2d):
  286. if hasattr(conv_layer, "stride") and conv_layer.stride != (1, 1):
  287. padding = conv_layer.padding
  288. kernel_size = conv_layer.kernel_size[0]
  289. stride = conv_layer.stride[0]
  290. output_lengths = (input_lengths + padding[0] + padding[1] - kernel_size) // stride + 1
  291. return output_lengths
  292. return input_lengths
  293. def forward(self, input_features: torch.Tensor, attention_mask: torch.Tensor = None):
  294. hidden_states = input_features.unsqueeze(1)
  295. current_lengths = attention_mask.sum(-1) if attention_mask is not None else None
  296. for layer in self.layers:
  297. hidden_states = layer(hidden_states)
  298. # mask the hidden states
  299. if isinstance(layer, nn.Conv2d) and attention_mask is not None:
  300. current_lengths = self._get_output_length(current_lengths, layer)
  301. current_seq_length = hidden_states.shape[2]
  302. channel_mask = (
  303. torch.arange(current_seq_length, device=attention_mask.device) < current_lengths[:, None]
  304. )
  305. hidden_states *= channel_mask[:, None, :, None]
  306. hidden_states = hidden_states.transpose(1, 2).reshape(hidden_states.shape[0], hidden_states.shape[2], -1)
  307. hidden_states = self.linear(hidden_states)
  308. return hidden_states
  309. class ParakeetEncoderBlock(GradientCheckpointingLayer):
  310. def __init__(self, config: ParakeetEncoderConfig, layer_idx: Optional[int] = None):
  311. super().__init__()
  312. self.gradient_checkpointing = False
  313. self.feed_forward1 = ParakeetEncoderFeedForward(config)
  314. self.self_attn = ParakeetEncoderAttention(config, layer_idx)
  315. self.conv = ParakeetEncoderConvolutionModule(config)
  316. self.feed_forward2 = ParakeetEncoderFeedForward(config)
  317. self.norm_feed_forward1 = nn.LayerNorm(config.hidden_size)
  318. self.norm_self_att = nn.LayerNorm(config.hidden_size)
  319. self.norm_conv = nn.LayerNorm(config.hidden_size)
  320. self.norm_feed_forward2 = nn.LayerNorm(config.hidden_size)
  321. self.norm_out = nn.LayerNorm(config.hidden_size)
  322. def forward(
  323. self,
  324. hidden_states: torch.Tensor,
  325. attention_mask: Optional[torch.Tensor] = None,
  326. position_embeddings: Optional[torch.Tensor] = None,
  327. **kwargs: Unpack[TransformersKwargs],
  328. ) -> torch.Tensor:
  329. residual = hidden_states
  330. hidden_states = self.feed_forward1(self.norm_feed_forward1(hidden_states))
  331. hidden_states = residual + 0.5 * hidden_states # the conformer architecture uses a factor of 0.5
  332. normalized_hidden_states = self.norm_self_att(hidden_states)
  333. attn_output, _ = self.self_attn(
  334. hidden_states=normalized_hidden_states,
  335. attention_mask=attention_mask,
  336. position_embeddings=position_embeddings,
  337. **kwargs,
  338. )
  339. hidden_states = hidden_states + attn_output
  340. conv_output = self.conv(self.norm_conv(hidden_states), attention_mask=attention_mask)
  341. hidden_states = hidden_states + conv_output
  342. ff2_output = self.feed_forward2(self.norm_feed_forward2(hidden_states))
  343. hidden_states = hidden_states + 0.5 * ff2_output # the conformer architecture uses a factor of 0.5
  344. hidden_states = self.norm_out(hidden_states)
  345. return hidden_states
  346. @auto_docstring
  347. class ParakeetPreTrainedModel(PreTrainedModel):
  348. config: ParakeetCTCConfig
  349. base_model_prefix = "model"
  350. main_input_name = "input_features"
  351. supports_gradient_checkpointing = True
  352. _no_split_modules = ["ParakeetEncoderBlock"]
  353. _supports_flat_attention_mask = True
  354. _supports_sdpa = True
  355. _supports_flex_attn = True
  356. # TODO: @eustlb, add support when flash attention supports custom attention bias
  357. _supports_flash_attn = False
  358. _can_compile_fullgraph = True
  359. _supports_attention_backend = True
  360. _can_record_outputs = {
  361. "hidden_states": ParakeetEncoderBlock,
  362. "attentions": ParakeetEncoderAttention,
  363. }
  364. def _init_weights(self, module):
  365. super()._init_weights(module)
  366. if hasattr(self.config, "initializer_range"):
  367. std = self.config.initializer_range
  368. else:
  369. # 0.02 is the standard default value accross the library
  370. std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
  371. if isinstance(module, ParakeetEncoderAttention):
  372. # Initialize positional bias parameters
  373. module.bias_u.data.normal_(mean=0.0, std=std)
  374. module.bias_v.data.normal_(mean=0.0, std=std)
  375. def _get_subsampling_output_length(self, input_lengths: torch.Tensor):
  376. encoder_config = self.config.encoder_config if isinstance(self.config, ParakeetCTCConfig) else self.config
  377. kernel_size = encoder_config.subsampling_conv_kernel_size
  378. stride = encoder_config.subsampling_conv_stride
  379. num_layers = int(math.log2(encoder_config.subsampling_factor))
  380. all_paddings = (kernel_size - 1) // 2 * 2
  381. add_pad = all_paddings - kernel_size
  382. lengths = input_lengths
  383. for _ in range(num_layers):
  384. lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + 1.0
  385. lengths = torch.floor(lengths)
  386. return lengths.to(dtype=torch.int)
  387. def _get_output_attention_mask(self, attention_mask: torch.Tensor, target_length: Optional[int] = None):
  388. """
  389. Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
  390. when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
  391. """
  392. output_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  393. # Use target_length if provided, otherwise use max length in batch
  394. max_length = target_length if target_length is not None else output_lengths.max()
  395. attention_mask = torch.arange(max_length, device=attention_mask.device) < output_lengths[:, None]
  396. return attention_mask
  397. @auto_docstring(
  398. custom_intro="""
  399. The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
  400. """
  401. )
  402. class ParakeetEncoder(ParakeetPreTrainedModel):
  403. config: ParakeetEncoderConfig
  404. base_model_prefix = "encoder"
  405. def __init__(self, config: ParakeetEncoderConfig):
  406. super().__init__(config)
  407. self.config = config
  408. self.gradient_checkpointing = False
  409. self.dropout = config.dropout
  410. self.dropout_positions = config.dropout_positions
  411. self.layerdrop = config.layerdrop
  412. self.input_scale = math.sqrt(config.hidden_size) if config.scale_input else 1.0
  413. self.subsampling = ParakeetEncoderSubsamplingConv2D(config)
  414. self.encode_positions = ParakeetEncoderRelPositionalEncoding(config)
  415. self.layers = nn.ModuleList(
  416. [ParakeetEncoderBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  417. )
  418. self.post_init()
  419. @auto_docstring
  420. @check_model_inputs()
  421. @can_return_tuple
  422. def forward(
  423. self,
  424. input_features: torch.Tensor,
  425. attention_mask: Optional[torch.Tensor] = None,
  426. **kwargs: Unpack[TransformersKwargs],
  427. ) -> BaseModelOutput:
  428. r"""
  429. Example:
  430. ```python
  431. >>> from transformers import AutoProcessor, ParakeetEncoder
  432. >>> from datasets import load_dataset, Audio
  433. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  434. >>> processor = AutoProcessor.from_pretrained(model_id)
  435. >>> encoder = ParakeetEncoder.from_pretrained(model_id)
  436. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  437. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  438. >>> inputs = processor(ds[0]["audio"]["array"])
  439. >>> encoder_outputs = encoder(**inputs)
  440. >>> print(encoder_outputs.last_hidden_state.shape)
  441. ```
  442. """
  443. hidden_states = self.subsampling(input_features, attention_mask)
  444. hidden_states = hidden_states * self.input_scale
  445. position_embeddings = self.encode_positions(hidden_states)
  446. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  447. position_embeddings = nn.functional.dropout(
  448. position_embeddings, p=self.dropout_positions, training=self.training
  449. )
  450. if attention_mask is not None:
  451. attention_mask = self._get_output_attention_mask(attention_mask, target_length=hidden_states.shape[1])
  452. attention_mask = attention_mask.unsqueeze(1).expand(-1, hidden_states.shape[1], -1)
  453. attention_mask = attention_mask & attention_mask.transpose(1, 2)
  454. attention_mask = attention_mask.unsqueeze(1)
  455. for encoder_layer in self.layers:
  456. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  457. to_drop = False
  458. if self.training:
  459. dropout_probability = torch.rand([])
  460. if dropout_probability < self.layerdrop: # skip the layer
  461. to_drop = True
  462. if not to_drop:
  463. hidden_states = encoder_layer(
  464. hidden_states,
  465. attention_mask=attention_mask,
  466. position_embeddings=position_embeddings,
  467. **kwargs,
  468. )
  469. return BaseModelOutput(last_hidden_state=hidden_states)
@dataclass
class ParakeetGenerateOutput(ModelOutput):
    """
    Outputs of Parakeet models.

    NOTE(review): the field descriptions below use generic generation-output wording. In this file
    `ParakeetForCTC.generate` fills `sequences` with the per-frame argmax of a single forward pass and copies
    `logits`, `attentions` and `hidden_states` directly from that forward pass - confirm the documented
    shapes against it.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    """

    # Only `sequences` is required; the remaining fields default to `None`.
    sequences: torch.LongTensor
    logits: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[tuple[tuple[torch.FloatTensor]]] = None
  493. @auto_docstring(
  494. custom_intro="""
  495. Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
  496. """
  497. )
  498. class ParakeetForCTC(ParakeetPreTrainedModel):
  499. config: ParakeetCTCConfig
  500. def __init__(self, config: ParakeetCTCConfig):
  501. super().__init__(config)
  502. self.encoder = ParakeetEncoder(config.encoder_config)
  503. # Conv rather than linear to be consistent with NeMO decoding layer
  504. self.ctc_head = nn.Conv1d(config.encoder_config.hidden_size, config.vocab_size, kernel_size=1)
  505. self.post_init()
  506. @auto_docstring
  507. @can_return_tuple
  508. def forward(
  509. self,
  510. input_features: torch.Tensor,
  511. attention_mask: Optional[torch.Tensor] = None,
  512. labels: Optional[torch.Tensor] = None,
  513. **kwargs: Unpack[TransformersKwargs],
  514. ) -> CausalLMOutput:
  515. r"""
  516. Example:
  517. ```python
  518. >>> from transformers import AutoProcessor, ParakeetForCTC
  519. >>> from datasets import load_dataset, Audio
  520. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  521. >>> processor = AutoProcessor.from_pretrained(model_id)
  522. >>> model = ParakeetForCTC.from_pretrained(model_id)
  523. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  524. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  525. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  526. >>> outputs = model(**inputs)
  527. >>> print(outputs.loss)
  528. ```"""
  529. encoder_outputs = self.encoder(
  530. input_features=input_features,
  531. attention_mask=attention_mask,
  532. **kwargs,
  533. )
  534. hidden_states = encoder_outputs.last_hidden_state
  535. logits = self.ctc_head(hidden_states.transpose(1, 2)).transpose(1, 2)
  536. loss = None
  537. if labels is not None:
  538. # retrieve loss input_lengths from attention_mask
  539. attention_mask = (
  540. attention_mask if attention_mask is not None else torch.ones_like(input_features, dtype=torch.long)
  541. )
  542. input_lengths = self._get_subsampling_output_length(attention_mask.sum(-1))
  543. # assuming that padded tokens are filled with -100
  544. # when not being attended to
  545. labels_mask = labels != self.config.pad_token_id
  546. target_lengths = labels_mask.sum(-1)
  547. flattened_targets = labels.masked_select(labels_mask)
  548. # ctc_loss doesn't support fp16
  549. log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
  550. with torch.backends.cudnn.flags(enabled=False):
  551. loss = nn.functional.ctc_loss(
  552. log_probs,
  553. flattened_targets,
  554. input_lengths,
  555. target_lengths,
  556. blank=self.config.pad_token_id,
  557. reduction=self.config.ctc_loss_reduction,
  558. zero_infinity=self.config.ctc_zero_infinity,
  559. )
  560. return CausalLMOutput(
  561. loss=loss,
  562. logits=logits,
  563. hidden_states=encoder_outputs.hidden_states,
  564. attentions=encoder_outputs.attentions,
  565. )
  566. @torch.no_grad()
  567. def generate(
  568. self,
  569. input_features: torch.Tensor,
  570. attention_mask: Optional[torch.Tensor] = None,
  571. return_dict_in_generate: bool = False,
  572. **kwargs: Unpack[TransformersKwargs],
  573. ) -> Union[ParakeetGenerateOutput, torch.LongTensor]:
  574. r"""
  575. Example:
  576. ```python
  577. >>> from transformers import AutoProcessor, ParakeetForCTC
  578. >>> from datasets import load_dataset, Audio
  579. >>> model_id = "nvidia/parakeet-ctc-1.1b"
  580. >>> processor = AutoProcessor.from_pretrained(model_id)
  581. >>> model = ParakeetForCTC.from_pretrained(model_id)
  582. >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
  583. >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))
  584. >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
  585. >>> predicted_ids = model.generate(**inputs)
  586. >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
  587. >>> print(transcription)
  588. ```
  589. """
  590. kwargs["return_dict"] = True
  591. outputs: CausalLMOutput = self.forward(
  592. input_features=input_features,
  593. attention_mask=attention_mask,
  594. **kwargs,
  595. )
  596. # greedy decoding
  597. sequences = outputs.logits.argmax(dim=-1)
  598. # mask out padded tokens
  599. if attention_mask is not None:
  600. attention_mask = self._get_output_attention_mask(attention_mask, target_length=sequences.shape[1])
  601. sequences[~attention_mask] = self.config.pad_token_id
  602. if return_dict_in_generate:
  603. return ParakeetGenerateOutput(
  604. sequences=sequences,
  605. logits=outputs.logits,
  606. attentions=outputs.attentions,
  607. hidden_states=outputs.hidden_states,
  608. )
  609. return sequences
# Names exported by `from <this module> import *`.
__all__ = ["ParakeetForCTC", "ParakeetEncoder", "ParakeetPreTrainedModel"]