modular_unispeech.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch UniSpeech model."""
  16. import math
  17. import warnings
  18. from dataclasses import dataclass
  19. from typing import Optional, Union
  20. import torch
  21. import torch.nn as nn
  22. from ...modeling_outputs import ModelOutput, Wav2Vec2BaseModelOutput
  23. from ...modeling_utils import PreTrainedModel
  24. from ...utils import auto_docstring, logging
  25. from ..wav2vec2.modeling_wav2vec2 import (
  26. Wav2Vec2Encoder,
  27. Wav2Vec2EncoderStableLayerNorm,
  28. Wav2Vec2FeatureEncoder,
  29. Wav2Vec2FeatureProjection,
  30. Wav2Vec2ForCTC,
  31. Wav2Vec2ForSequenceClassification,
  32. Wav2Vec2GumbelVectorQuantizer,
  33. Wav2Vec2Model,
  34. Wav2Vec2PositionalConvEmbedding,
  35. )
  36. from .configuration_unispeech import UniSpeechConfig
  37. logger = logging.get_logger(__name__)
  38. @dataclass
  39. @auto_docstring(
  40. custom_intro="""
  41. Output type of [`UniSpeechForPreTrainingOutput`], with potential hidden states and attentions.
  42. """
  43. )
  44. class UniSpeechForPreTrainingOutput(ModelOutput):
  45. r"""
  46. loss (*optional*, returned when model is in train mode, `torch.FloatTensor` of shape `(1,)`):
  47. Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official
  48. paper](https://huggingface.co/papers/2006.11477).
  49. projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  50. Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked
  51. projected quantized states.
  52. projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`):
  53. Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive
  54. target vectors for contrastive loss.
  55. codevector_perplexity (`torch.FloatTensor` of shape `(1,)`):
  56. The perplexity of the codevector distribution, used to measure the diversity of the codebook.
  57. """
  58. loss: Optional[torch.FloatTensor] = None
  59. projected_states: Optional[torch.FloatTensor] = None
  60. projected_quantized_states: Optional[torch.FloatTensor] = None
  61. codevector_perplexity: Optional[torch.FloatTensor] = None
  62. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  63. attentions: Optional[tuple[torch.FloatTensor]] = None
  64. class UniSpeechPositionalConvEmbedding(Wav2Vec2PositionalConvEmbedding):
  65. pass
  66. class UniSpeechFeatureEncoder(Wav2Vec2FeatureEncoder):
  67. pass
  68. class UniSpeechFeatureProjection(Wav2Vec2FeatureProjection):
  69. pass
  70. class UniSpeechEncoder(Wav2Vec2Encoder):
  71. pass
  72. class UniSpeechEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
  73. pass
  74. class UniSpeechGumbelVectorQuantizer(Wav2Vec2GumbelVectorQuantizer):
  75. @staticmethod
  76. def _compute_perplexity(probs):
  77. marginal_probs = probs.mean(dim=0)
  78. perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
  79. return perplexity
  80. def forward(self, hidden_states):
  81. batch_size, sequence_length, hidden_size = hidden_states.shape
  82. # project to codevector dim
  83. hidden_states = self.weight_proj(hidden_states)
  84. hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1)
  85. if self.training:
  86. # sample code vector probs via gumbel in differentiateable way
  87. codevector_probs = nn.functional.gumbel_softmax(
  88. hidden_states.float(), tau=self.temperature, hard=True
  89. ).type_as(hidden_states)
  90. # compute perplexity
  91. codevector_soft_dist = torch.softmax(
  92. hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1
  93. )
  94. perplexity = self._compute_perplexity(codevector_soft_dist)
  95. else:
  96. # take argmax in non-differentiable way
  97. # comptute hard codevector distribution (one hot)
  98. codevector_idx = hidden_states.argmax(dim=-1)
  99. codevector_probs = hidden_states.new_zeros(*hidden_states.shape).scatter_(
  100. -1, codevector_idx.view(-1, 1), 1.0
  101. )
  102. codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1)
  103. perplexity = self._compute_perplexity(codevector_probs)
  104. codevector_probs = codevector_probs.view(batch_size * sequence_length, -1)
  105. # use probs to retrieve codevectors
  106. codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors
  107. codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1)
  108. codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1)
  109. return codevectors, perplexity
  110. @auto_docstring
  111. class UniSpeechPreTrainedModel(PreTrainedModel):
  112. config: UniSpeechConfig
  113. base_model_prefix = "unispeech"
  114. main_input_name = "input_values"
  115. supports_gradient_checkpointing = True
  116. _supports_flash_attn = True
  117. _supports_sdpa = True
  118. _supports_flex_attn = True
  119. def _init_weights(self, module):
  120. """Initialize the weights"""
  121. # gumbel softmax requires special init
  122. if isinstance(module, UniSpeechGumbelVectorQuantizer):
  123. module.weight_proj.weight.data.normal_(mean=0.0, std=1)
  124. module.weight_proj.bias.data.zero_()
  125. nn.init.uniform_(module.codevectors)
  126. elif isinstance(module, UniSpeechPositionalConvEmbedding):
  127. nn.init.normal_(
  128. module.conv.weight,
  129. mean=0,
  130. std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
  131. )
  132. nn.init.constant_(module.conv.bias, 0)
  133. elif isinstance(module, UniSpeechFeatureProjection):
  134. k = math.sqrt(1 / module.projection.in_features)
  135. nn.init.uniform_(module.projection.weight, a=-k, b=k)
  136. nn.init.uniform_(module.projection.bias, a=-k, b=k)
  137. elif isinstance(module, nn.Linear):
  138. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  139. if module.bias is not None:
  140. module.bias.data.zero_()
  141. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
  142. module.bias.data.zero_()
  143. module.weight.data.fill_(1.0)
  144. elif isinstance(module, nn.Conv1d):
  145. nn.init.kaiming_normal_(module.weight)
  146. if module.bias is not None:
  147. k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
  148. nn.init.uniform_(module.bias, a=-k, b=k)
  149. def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
  150. """
  151. Computes the output length of the convolutional layers
  152. """
  153. def _conv_out_length(input_length, kernel_size, stride):
  154. # 1D convolutional layer output length formula taken
  155. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  156. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  157. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  158. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  159. return input_lengths
  160. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  161. # Effectively attention_mask.sum(-1), but not inplace to be able to run
  162. # on inference mode.
  163. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
  164. output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
  165. batch_size = attention_mask.shape[0]
  166. attention_mask = torch.zeros(
  167. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  168. )
  169. # these two operations makes sure that all values before the output lengths idxs are attended to
  170. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  171. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  172. return attention_mask
  173. UniSpeechBaseModelOutput = Wav2Vec2BaseModelOutput
  174. class UniSpeechModel(UniSpeechPreTrainedModel, Wav2Vec2Model):
  175. def __init__(self, config: UniSpeechConfig):
  176. UniSpeechPreTrainedModel.__init__(self, config)
  177. self.config = config
  178. self.feature_extractor = UniSpeechFeatureEncoder(config)
  179. self.feature_projection = UniSpeechFeatureProjection(config)
  180. if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
  181. self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
  182. if config.do_stable_layer_norm:
  183. self.encoder = UniSpeechEncoderStableLayerNorm(config)
  184. else:
  185. self.encoder = UniSpeechEncoder(config)
  186. # Initialize weights and apply final processing
  187. self.post_init()
  188. def freeze_feature_extractor(self):
  189. raise AttributeError("Not needed for UniSpeech")
  190. def freeze_feature_encoder(self):
  191. raise AttributeError("Not needed for UniSpeech")
  192. def forward(
  193. self,
  194. input_values: Optional[torch.Tensor],
  195. attention_mask: Optional[torch.Tensor] = None,
  196. mask_time_indices: Optional[torch.FloatTensor] = None,
  197. output_attentions: Optional[bool] = None,
  198. output_hidden_states: Optional[bool] = None,
  199. return_dict: Optional[bool] = None,
  200. ) -> Union[tuple, UniSpeechBaseModelOutput]:
  201. r"""
  202. mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
  203. Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
  204. masked extracted features in *config.proj_codevector_dim* space.
  205. """
  206. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  207. output_hidden_states = (
  208. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  209. )
  210. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  211. extract_features = self.feature_extractor(input_values)
  212. extract_features = extract_features.transpose(1, 2)
  213. if attention_mask is not None:
  214. # compute reduced attention_mask corresponding to feature vectors
  215. attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
  216. hidden_states, extract_features = self.feature_projection(extract_features)
  217. hidden_states = self._mask_hidden_states(
  218. hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
  219. )
  220. encoder_outputs = self.encoder(
  221. hidden_states,
  222. attention_mask=attention_mask,
  223. output_attentions=output_attentions,
  224. output_hidden_states=output_hidden_states,
  225. return_dict=return_dict,
  226. )
  227. hidden_states = encoder_outputs[0]
  228. if not return_dict:
  229. return (hidden_states, extract_features) + encoder_outputs[1:]
  230. return UniSpeechBaseModelOutput(
  231. last_hidden_state=hidden_states,
  232. extract_features=extract_features,
  233. hidden_states=encoder_outputs.hidden_states,
  234. attentions=encoder_outputs.attentions,
  235. )
  236. @auto_docstring(
  237. custom_intro="""
  238. UniSpeech Model with a vector-quantization module and ctc loss for pre-training.
  239. """
  240. )
  241. class UniSpeechForPreTraining(UniSpeechPreTrainedModel):
  242. def __init__(self, config: UniSpeechConfig):
  243. super().__init__(config)
  244. self.unispeech = UniSpeechModel(config)
  245. self.dropout_features = nn.Dropout(config.feat_quantizer_dropout)
  246. self.quantizer = UniSpeechGumbelVectorQuantizer(config)
  247. self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim)
  248. self.project_hid = nn.Linear(config.proj_codevector_dim, config.hidden_size)
  249. self.ctc_proj = nn.Linear(config.hidden_size, config.num_ctc_classes)
  250. self.dropout = nn.Dropout(config.final_dropout)
  251. # Initialize weights and apply final processing
  252. self.post_init()
  253. def set_gumbel_temperature(self, temperature: int):
  254. """
  255. Set the Gumbel softmax temperature to a given value. Only necessary for training
  256. """
  257. self.quantizer.temperature = temperature
  258. def freeze_feature_extractor(self):
  259. """
  260. Calling this function will disable the gradient computation for the feature encoder so that its parameters will
  261. not be updated during training.
  262. """
  263. warnings.warn(
  264. "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5. "
  265. "Please use the equivalent `freeze_feature_encoder` method instead.",
  266. FutureWarning,
  267. )
  268. self.freeze_feature_encoder()
  269. def freeze_feature_encoder(self):
  270. """
  271. Calling this function will disable the gradient computation for the feature encoder so that its parameter will
  272. not be updated during training.
  273. """
  274. self.unispeech.feature_extractor._freeze_parameters()
  275. @staticmethod
  276. def compute_contrastive_logits(
  277. target_features: torch.FloatTensor,
  278. negative_features: torch.FloatTensor,
  279. predicted_features: torch.FloatTensor,
  280. temperature: int = 1,
  281. ):
  282. """
  283. Compute logits for contrastive loss based using cosine similarity as the distance measure between
  284. `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
  285. """
  286. target_features = torch.cat([target_features, negative_features], dim=0)
  287. logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1)
  288. logits = logits.type_as(target_features)
  289. # apply temperature
  290. logits = logits / temperature
  291. return logits
  292. @auto_docstring
  293. def forward(
  294. self,
  295. input_values: Optional[torch.Tensor],
  296. attention_mask: Optional[torch.Tensor] = None,
  297. output_attentions: Optional[bool] = None,
  298. output_hidden_states: Optional[bool] = None,
  299. return_dict: Optional[bool] = None,
  300. ) -> Union[tuple, UniSpeechForPreTrainingOutput]:
  301. r"""
  302. Example:
  303. ```python
  304. >>> import torch
  305. >>> from transformers import AutoFeatureExtractor, UniSpeechForPreTraining
  306. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/unispeech-large-1500h-cv")
  307. >>> model = UniSpeechForPreTraining.from_pretrained("microsoft/unispeech-large-1500h-cv")
  308. >>> # TODO: Add full pretraining example
  309. ```"""
  310. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  311. outputs = self.unispeech(
  312. input_values,
  313. attention_mask=attention_mask,
  314. output_attentions=output_attentions,
  315. output_hidden_states=output_hidden_states,
  316. return_dict=return_dict,
  317. )
  318. transformer_features = outputs[0]
  319. # quantize all (unmasked) extracted features and project to final vq dim
  320. extract_features = self.dropout_features(outputs[1])
  321. quantized_features, codevector_perplexity = self.quantizer(extract_features)
  322. # project quantized features twice
  323. quantized_features = self.project_q(quantized_features.to(self.project_q.weight.dtype))
  324. quantized_features = self.project_hid(quantized_features)
  325. prob_replace_matrix = torch.empty(transformer_features.size(0), transformer_features.size(1)).fill_(
  326. self.config.replace_prob
  327. )
  328. prob_replace_matrix = prob_replace_matrix.transpose(0, 1)
  329. sampled_replace_matrix = torch.bernoulli(prob_replace_matrix).bool().to(transformer_features.device)
  330. sampled_replace_matrix = sampled_replace_matrix.transpose(0, 1)
  331. sampled_replace_matrix = sampled_replace_matrix.unsqueeze(-1)
  332. logits = transformer_features.masked_fill(sampled_replace_matrix, 0.0) + (
  333. quantized_features.masked_fill(~sampled_replace_matrix, 0.0)
  334. )
  335. # project to ctc units
  336. logits = self.dropout(logits)
  337. logits = self.ctc_proj(logits)
  338. # TODO(PVP) - add negative sampling & loss computation
  339. loss = None
  340. if not return_dict:
  341. if loss is not None:
  342. return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  343. return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:]
  344. return UniSpeechForPreTrainingOutput(
  345. loss=loss,
  346. projected_states=transformer_features,
  347. projected_quantized_states=quantized_features,
  348. codevector_perplexity=codevector_perplexity,
  349. hidden_states=outputs.hidden_states,
  350. attentions=outputs.attentions,
  351. )
  352. class UniSpeechForCTC(Wav2Vec2ForCTC):
  353. pass
  354. class UniSpeechForSequenceClassification(Wav2Vec2ForSequenceClassification):
  355. pass
  356. __all__ = [
  357. "UniSpeechForCTC",
  358. "UniSpeechForPreTraining",
  359. "UniSpeechForSequenceClassification",
  360. "UniSpeechModel",
  361. "UniSpeechPreTrainedModel",
  362. ]