modular_sew.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460
  1. # coding=utf-8
  2. # Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch SEW model."""
  16. import math
  17. import warnings
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  23. from ...integrations.fsdp import is_fsdp_managed_module
  24. from ...modeling_outputs import BaseModelOutput
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import auto_docstring
  27. from ..wav2vec2.modeling_wav2vec2 import (
  28. Wav2Vec2Attention,
  29. Wav2Vec2EncoderLayer,
  30. Wav2Vec2FeatureEncoder,
  31. Wav2Vec2FeedForward,
  32. Wav2Vec2ForCTC,
  33. Wav2Vec2ForSequenceClassification,
  34. Wav2Vec2GroupNormConvLayer,
  35. Wav2Vec2LayerNormConvLayer,
  36. Wav2Vec2NoLayerNormConvLayer,
  37. Wav2Vec2SamePadLayer,
  38. _compute_mask_indices,
  39. )
  40. from .configuration_sew import SEWConfig
# Module-level constant; nothing else in the visible part of this file reads it.
# NOTE(review): presumably consumed by the inherited head classes when they index
# into the returned hidden states — confirm against the Wav2Vec2 parent classes.
_HIDDEN_STATES_START_POSITION = 1
# SEW-named subclass of `Wav2Vec2NoLayerNormConvLayer`; it overrides nothing.
class SEWNoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer):
    pass
# SEW-named subclass of `Wav2Vec2LayerNormConvLayer`; it overrides nothing.
class SEWLayerNormConvLayer(Wav2Vec2LayerNormConvLayer):
    pass
# SEW-named subclass of `Wav2Vec2GroupNormConvLayer`; it overrides nothing.
class SEWGroupNormConvLayer(Wav2Vec2GroupNormConvLayer):
    pass
  48. class SEWPositionalConvEmbedding(nn.Module):
  49. def __init__(self, config):
  50. super().__init__()
  51. self.conv = nn.Conv1d(
  52. config.hidden_size,
  53. config.hidden_size,
  54. kernel_size=config.num_conv_pos_embeddings,
  55. padding=config.num_conv_pos_embeddings // 2,
  56. groups=config.num_conv_pos_embedding_groups,
  57. stride=config.squeeze_factor,
  58. )
  59. weight_norm = nn.utils.weight_norm
  60. if hasattr(nn.utils.parametrizations, "weight_norm"):
  61. weight_norm = nn.utils.parametrizations.weight_norm
  62. if is_deepspeed_zero3_enabled():
  63. import deepspeed
  64. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  65. self.conv = weight_norm(self.conv, name="weight", dim=2)
  66. if hasattr(self.conv, "parametrizations"):
  67. weight_g = self.conv.parametrizations.weight.original0
  68. weight_v = self.conv.parametrizations.weight.original1
  69. else:
  70. weight_g = self.conv.weight_g
  71. weight_v = self.conv.weight_v
  72. deepspeed.zero.register_external_parameter(self, weight_v)
  73. deepspeed.zero.register_external_parameter(self, weight_g)
  74. else:
  75. self.conv = weight_norm(self.conv, name="weight", dim=2)
  76. self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
  77. self.activation = ACT2FN[config.feat_extract_activation]
  78. def forward(self, hidden_states):
  79. hidden_states = self.conv(hidden_states)
  80. hidden_states = self.padding(hidden_states)
  81. hidden_states = self.activation(hidden_states)
  82. return hidden_states
# SEW-named subclass of `Wav2Vec2SamePadLayer`; it overrides nothing.
class SEWSamePadLayer(Wav2Vec2SamePadLayer):
    pass
  85. class SEWUpsampling(nn.Module):
  86. def __init__(self, config):
  87. super().__init__()
  88. self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
  89. self.activation = ACT2FN[config.feat_extract_activation]
  90. self.squeeze_factor = config.squeeze_factor
  91. def forward(self, hidden_states):
  92. hidden_states = self.projection(hidden_states)
  93. hidden_states = self.activation(hidden_states)
  94. if self.squeeze_factor > 1:
  95. # transform embedding channels to sequence length
  96. bsz, src_len, src_embed_dim = hidden_states.size()
  97. tgt_len = src_len * self.squeeze_factor
  98. tgt_embed_dim = src_embed_dim // self.squeeze_factor
  99. hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
  100. hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
  101. return hidden_states
# SEW-named subclass of `Wav2Vec2FeatureEncoder`; it overrides nothing.
class SEWFeatureEncoder(Wav2Vec2FeatureEncoder):
    pass
  104. class SEWFeatureExtractor(SEWFeatureEncoder):
  105. def __init__(self, config):
  106. super().__init__(config)
  107. warnings.warn(
  108. f"The class `{self.__class__.__name__}` has been depreciated "
  109. "and will be removed in Transformers v5. "
  110. f"Use `{self.__class__.__bases__[0].__name__}` instead.",
  111. FutureWarning,
  112. )
# SEW-named subclass of `Wav2Vec2Attention`; it overrides nothing.
class SEWAttention(Wav2Vec2Attention):
    pass
# SEW-named subclass of `Wav2Vec2FeedForward`; it overrides nothing.
class SEWFeedForward(Wav2Vec2FeedForward):
    pass
# SEW-named subclass of `Wav2Vec2EncoderLayer`; it overrides nothing.
class SEWEncoderLayer(Wav2Vec2EncoderLayer):
    pass
  119. class SEWEncoder(nn.Module):
  120. def __init__(self, config):
  121. super().__init__()
  122. self.config = config
  123. self.pos_conv_embed = SEWPositionalConvEmbedding(config)
  124. self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
  125. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  126. self.dropout = nn.Dropout(config.hidden_dropout)
  127. self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  128. self.upsample = SEWUpsampling(config)
  129. self.gradient_checkpointing = False
  130. def forward(
  131. self,
  132. hidden_states,
  133. attention_mask=None,
  134. output_attentions=False,
  135. output_hidden_states=False,
  136. return_dict=True,
  137. ):
  138. all_hidden_states = () if output_hidden_states else None
  139. all_self_attentions = () if output_attentions else None
  140. if attention_mask is not None:
  141. expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
  142. if self.config._attn_implementation == "flash_attention_2":
  143. # make sure padded tokens output 0
  144. hidden_states[~expand_attention_mask] = 0.0
  145. # 2d mask is passed through the layers
  146. attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
  147. else:
  148. # make sure padded tokens output 0
  149. hidden_states[~expand_attention_mask] = 0.0
  150. input_lengths = (attention_mask.long()).sum(-1)
  151. # apply pooling formula to get real output_lengths
  152. output_lengths = input_lengths // self.config.squeeze_factor
  153. max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
  154. attention_ids = (
  155. torch.arange(0, max_encoder_length, device=output_lengths.device)
  156. .view(1, -1)
  157. .expand(output_lengths.shape[0], -1)
  158. )
  159. attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
  160. # extend attention_mask
  161. attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
  162. attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
  163. attention_mask = attention_mask.expand(
  164. attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
  165. )
  166. n_input_timesteps = hidden_states.shape[1]
  167. hidden_states = hidden_states.transpose(1, 2)
  168. position_embeddings = self.pos_conv_embed(hidden_states)
  169. pooled_hidden_states = self.pool(hidden_states)
  170. min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
  171. hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
  172. hidden_states = hidden_states.transpose(1, 2)
  173. hidden_states = self.layer_norm(hidden_states)
  174. hidden_states = self.dropout(hidden_states)
  175. synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
  176. for layer in self.layers:
  177. if output_hidden_states:
  178. all_hidden_states = all_hidden_states + (hidden_states,)
  179. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  180. dropout_probability = torch.rand([])
  181. skip_the_layer = self.training and dropout_probability < self.config.layerdrop
  182. if not skip_the_layer or synced_gpus:
  183. # under fsdp or deepspeed zero3 all gpus must run in sync
  184. layer_outputs = layer(
  185. hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
  186. )
  187. hidden_states = layer_outputs[0]
  188. if skip_the_layer:
  189. layer_outputs = (None, None)
  190. if output_attentions:
  191. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  192. if output_hidden_states:
  193. all_hidden_states = all_hidden_states + (hidden_states,)
  194. hidden_states = self.upsample(hidden_states)
  195. if hidden_states.shape[1] < n_input_timesteps:
  196. hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
  197. if not return_dict:
  198. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  199. return BaseModelOutput(
  200. last_hidden_state=hidden_states,
  201. hidden_states=all_hidden_states,
  202. attentions=all_self_attentions,
  203. )
  204. @auto_docstring
  205. class SEWPreTrainedModel(PreTrainedModel):
  206. config: SEWConfig
  207. base_model_prefix = "sew"
  208. main_input_name = "input_values"
  209. supports_gradient_checkpointing = True
  210. _supports_flash_attn = True
  211. _supports_sdpa = True
  212. _supports_flex_attn = False # needs a proper look into the mask creation
  213. def _init_weights(self, module):
  214. """Initialize the weights"""
  215. if isinstance(module, SEWPositionalConvEmbedding):
  216. nn.init.normal_(
  217. module.conv.weight,
  218. mean=0,
  219. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  220. )
  221. nn.init.constant_(module.conv.bias, 0)
  222. elif isinstance(module, nn.Linear):
  223. # Slightly different from the TF version which uses truncated_normal for initialization
  224. # cf https://github.com/pytorch/pytorch/pull/5617
  225. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  226. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  227. module.bias.data.zero_()
  228. module.weight.data.fill_(1.0)
  229. elif isinstance(module, nn.Conv1d):
  230. if is_deepspeed_zero3_enabled():
  231. import deepspeed
  232. if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
  233. with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
  234. nn.init.kaiming_normal_(module.weight.data)
  235. else:
  236. with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
  237. nn.init.kaiming_normal_(module.weight.data)
  238. else:
  239. nn.init.kaiming_normal_(module.weight.data)
  240. if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
  241. module.bias.data.zero_()
  242. def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
  243. """
  244. Computes the output length of the convolutional layers
  245. """
  246. def _conv_out_length(input_length, kernel_size, stride):
  247. # 1D convolutional layer output length formula taken
  248. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  249. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  250. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  251. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  252. return input_lengths
  253. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  254. output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  255. batch_size = attention_mask.shape[0]
  256. attention_mask = torch.zeros(
  257. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  258. )
  259. # these two operations makes sure that all values before the output lengths idxs are attended to
  260. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  261. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  262. return attention_mask
  263. @auto_docstring
  264. class SEWModel(SEWPreTrainedModel):
  265. def __init__(self, config: SEWConfig):
  266. super().__init__(config)
  267. self.config = config
  268. self.feature_extractor = SEWFeatureEncoder(config)
  269. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  270. self.project_features = config.conv_dim[-1] != config.hidden_size
  271. if self.project_features:
  272. self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  273. self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
  274. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  275. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  276. self.encoder = SEWEncoder(config)
  277. # Initialize weights and apply final processing
  278. self.post_init()
  279. # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
  280. def _mask_hidden_states(
  281. self,
  282. hidden_states: torch.FloatTensor,
  283. mask_time_indices: Optional[torch.FloatTensor] = None,
  284. attention_mask: Optional[torch.LongTensor] = None,
  285. ):
  286. """
  287. Masks extracted features along time axis and/or along feature axis according to
  288. [SpecAugment](https://huggingface.co/papers/1904.08779).
  289. """
  290. # `config.apply_spec_augment` can set masking to False
  291. if not getattr(self.config, "apply_spec_augment", True):
  292. return hidden_states
  293. # generate indices & apply SpecAugment along time axis
  294. batch_size, sequence_length, hidden_size = hidden_states.size()
  295. if mask_time_indices is not None:
  296. # apply SpecAugment along time axis with given mask_time_indices
  297. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  298. elif self.config.mask_time_prob > 0 and self.training:
  299. mask_time_indices = _compute_mask_indices(
  300. (batch_size, sequence_length),
  301. mask_prob=self.config.mask_time_prob,
  302. mask_length=self.config.mask_time_length,
  303. attention_mask=attention_mask,
  304. min_masks=self.config.mask_time_min_masks,
  305. )
  306. mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
  307. hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
  308. if self.config.mask_feature_prob > 0 and self.training:
  309. # generate indices & apply SpecAugment along feature axis
  310. mask_feature_indices = _compute_mask_indices(
  311. (batch_size, hidden_size),
  312. mask_prob=self.config.mask_feature_prob,
  313. mask_length=self.config.mask_feature_length,
  314. min_masks=self.config.mask_feature_min_masks,
  315. )
  316. mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
  317. mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
  318. hidden_states[mask_feature_indices] = 0
  319. return hidden_states
  320. @auto_docstring
  321. def forward(
  322. self,
  323. input_values: Optional[torch.Tensor],
  324. attention_mask: Optional[torch.Tensor] = None,
  325. mask_time_indices: Optional[torch.FloatTensor] = None,
  326. output_attentions: Optional[bool] = None,
  327. output_hidden_states: Optional[bool] = None,
  328. return_dict: Optional[bool] = None,
  329. ) -> Union[tuple, BaseModelOutput]:
  330. r"""
  331. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  332. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  333. masked extracted features in *config.proj_codevector_dim* space.
  334. """
  335. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  336. output_hidden_states = (
  337. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  338. )
  339. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  340. extract_features = self.feature_extractor(input_values)
  341. extract_features = extract_features.transpose(1, 2)
  342. extract_features = self.layer_norm(extract_features)
  343. if self.project_features:
  344. extract_features = self.feature_projection(extract_features)
  345. hidden_states = self.feature_dropout(extract_features)
  346. if attention_mask is not None:
  347. # compute reduced attention_mask corresponding to feature vectors
  348. attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
  349. hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
  350. encoder_outputs = self.encoder(
  351. hidden_states,
  352. attention_mask=attention_mask,
  353. output_attentions=output_attentions,
  354. output_hidden_states=output_hidden_states,
  355. return_dict=return_dict,
  356. )
  357. hidden_states = encoder_outputs[0]
  358. if not return_dict:
  359. return (hidden_states,) + encoder_outputs[1:]
  360. return BaseModelOutput(
  361. last_hidden_state=hidden_states,
  362. hidden_states=encoder_outputs.hidden_states,
  363. attentions=encoder_outputs.attentions,
  364. )
# SEW-named subclass of `Wav2Vec2ForCTC`; it overrides nothing.
class SEWForCTC(Wav2Vec2ForCTC):
    pass
# SEW-named subclass of `Wav2Vec2ForSequenceClassification`; it overrides nothing.
class SEWForSequenceClassification(Wav2Vec2ForSequenceClassification):
    pass
# Names exported by `from <this module> import *`.
__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"]