modular_hubert.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Hubert model."""
  16. from typing import Optional, Union
  17. import torch
  18. import torch.nn as nn
  19. from ...activations import ACT2FN
  20. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  21. from ...modeling_outputs import BaseModelOutput
  22. from ...modeling_utils import PreTrainedModel
  23. from ...utils import auto_docstring
  24. from ..wav2vec2.modeling_wav2vec2 import (
  25. Wav2Vec2Encoder,
  26. Wav2Vec2EncoderStableLayerNorm,
  27. Wav2Vec2FeatureEncoder,
  28. Wav2Vec2ForCTC,
  29. Wav2Vec2ForSequenceClassification,
  30. Wav2Vec2Model,
  31. Wav2Vec2SamePadLayer,
  32. )
  33. from .configuration_hubert import HubertConfig
# NOTE(review): not referenced anywhere in this module chunk; presumably kept for code inherited from
# the Wav2Vec2 classes (index of the first hidden-states entry in tuple outputs) — confirm before removing.
_HIDDEN_STATES_START_POSITION = 1
  35. class HubertPositionalConvEmbedding(nn.Module):
  36. def __init__(self, config):
  37. super().__init__()
  38. self.conv = nn.Conv1d(
  39. config.hidden_size,
  40. config.hidden_size,
  41. kernel_size=config.num_conv_pos_embeddings,
  42. padding=config.num_conv_pos_embeddings // 2,
  43. groups=config.num_conv_pos_embedding_groups,
  44. )
  45. self.batch_norm = None
  46. if config.conv_pos_batch_norm:
  47. self.batch_norm = nn.BatchNorm1d(config.hidden_size)
  48. else:
  49. weight_norm = nn.utils.weight_norm
  50. if hasattr(nn.utils.parametrizations, "weight_norm"):
  51. weight_norm = nn.utils.parametrizations.weight_norm
  52. if is_deepspeed_zero3_enabled():
  53. import deepspeed
  54. with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
  55. self.conv = weight_norm(self.conv, name="weight", dim=2)
  56. if hasattr(self.conv, "parametrizations"):
  57. weight_g = self.conv.parametrizations.weight.original0
  58. weight_v = self.conv.parametrizations.weight.original1
  59. else:
  60. weight_g = self.conv.weight_g
  61. weight_v = self.conv.weight_v
  62. deepspeed.zero.register_external_parameter(self, weight_v)
  63. deepspeed.zero.register_external_parameter(self, weight_g)
  64. else:
  65. self.conv = weight_norm(self.conv, name="weight", dim=2)
  66. self.padding = HubertSamePadLayer(config.num_conv_pos_embeddings)
  67. self.activation = ACT2FN[config.feat_extract_activation]
  68. def forward(self, hidden_states):
  69. hidden_states = hidden_states.transpose(1, 2)
  70. if self.batch_norm is not None:
  71. hidden_states = self.batch_norm(hidden_states)
  72. hidden_states = self.conv(hidden_states)
  73. hidden_states = self.padding(hidden_states)
  74. hidden_states = self.activation(hidden_states)
  75. hidden_states = hidden_states.transpose(1, 2)
  76. return hidden_states
  77. class HubertSamePadLayer(Wav2Vec2SamePadLayer):
  78. pass
  79. class HubertFeatureEncoder(Wav2Vec2FeatureEncoder):
  80. pass
  81. class HubertFeatureProjection(nn.Module):
  82. def __init__(self, config):
  83. super().__init__()
  84. self.feat_proj_layer_norm = config.feat_proj_layer_norm
  85. if self.feat_proj_layer_norm:
  86. self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
  87. self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
  88. self.dropout = nn.Dropout(config.feat_proj_dropout)
  89. def forward(self, hidden_states):
  90. # non-projected hidden states are needed for quantization
  91. if self.feat_proj_layer_norm:
  92. hidden_states = self.layer_norm(hidden_states)
  93. hidden_states = self.projection(hidden_states)
  94. hidden_states = self.dropout(hidden_states)
  95. return hidden_states
  96. class HubertEncoder(Wav2Vec2Encoder):
  97. pass
  98. class HubertEncoderStableLayerNorm(Wav2Vec2EncoderStableLayerNorm):
  99. pass
  100. @auto_docstring
  101. class HubertPreTrainedModel(PreTrainedModel):
  102. config: HubertConfig
  103. base_model_prefix = "hubert"
  104. main_input_name = "input_values"
  105. supports_gradient_checkpointing = True
  106. _supports_flash_attn = True
  107. _supports_sdpa = True
  108. _supports_flex_attn = True
  109. def _init_weights(self, module):
  110. """Initialize the weights"""
  111. if isinstance(module, nn.Linear):
  112. # Slightly different from the TF version which uses truncated_normal for initialization
  113. # cf https://github.com/pytorch/pytorch/pull/5617
  114. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  115. if module.bias is not None:
  116. module.bias.data.zero_()
  117. elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm1d)):
  118. module.bias.data.zero_()
  119. module.weight.data.fill_(1.0)
  120. elif isinstance(module, nn.Conv1d):
  121. if is_deepspeed_zero3_enabled():
  122. import deepspeed
  123. if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
  124. with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
  125. nn.init.kaiming_normal_(module.weight.data)
  126. else:
  127. with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
  128. nn.init.kaiming_normal_(module.weight.data)
  129. else:
  130. nn.init.kaiming_normal_(module.weight.data)
  131. if module.bias is not None:
  132. module.bias.data.zero_()
  133. elif isinstance(module, HubertModel):
  134. if hasattr(module, "masked_spec_embed"):
  135. module.masked_spec_embed.data.uniform_()
  136. elif isinstance(module, HubertForSequenceClassification):
  137. if hasattr(module, "layer_weights"):
  138. module.layer_weights.data.fill_(1.0 / (self.config.num_hidden_layers + 1))
  139. def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
  140. """
  141. Computes the output length of the convolutional layers
  142. """
  143. def _conv_out_length(input_length, kernel_size, stride):
  144. # 1D convolutional layer output length formula taken
  145. # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
  146. return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
  147. for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
  148. input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
  149. return input_lengths
  150. def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
  151. output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
  152. batch_size = attention_mask.shape[0]
  153. attention_mask = torch.zeros(
  154. (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
  155. )
  156. # these two operations makes sure that all values before the output lengths idxs are attended to
  157. attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
  158. attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
  159. return attention_mask
# NOTE(review): this class inherits from `Wav2Vec2Model` and uses `del self.adapter` plus
# AttributeError-raising overrides, which looks like a transformers "modular" definition whose
# statements also steer code generation — confirm before restructuring or reordering.
class HubertModel(Wav2Vec2Model, HubertPreTrainedModel):
    def __init__(self, config: HubertConfig):
        super().__init__(config)
        self.config = config
        # Hubert-specific submodules, assigned after `super().__init__` so they take precedence over
        # any same-named attributes set during parent construction.
        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = HubertFeatureProjection(config)

        # Learnable vector, only created when time or feature masking is configured.
        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())

        if config.do_stable_layer_norm:
            self.encoder = HubertEncoderStableLayerNorm(config)
        else:
            self.encoder = HubertEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

        # `adapter` was set during parent construction (otherwise this `del` would raise); Hubert drops it.
        del self.adapter

    # Raising AttributeError disables this inherited method for Hubert.
    def freeze_feature_extractor(self):
        raise AttributeError("Not needed for Hubert")

    # Raising AttributeError disables this inherited method for Hubert.
    def freeze_feature_encoder(self):
        raise AttributeError("Not needed for Hubert")

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        # NOTE(review): annotated as FloatTensor but documented as BoolTensor below — confirm which is intended.
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        r"""
        mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.

        Example:

        ```python
        >>> from transformers import AutoProcessor, HubertModel
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
        >>> model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")

        >>> def map_to_array(example):
        ...     example["speech"] = example["audio"]["array"]
        ...     return example

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.map(map_to_array)

        >>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
        >>> hidden_states = model(input_values).last_hidden_state
        ```"""

        # Fall back to config defaults for any flag not passed explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        # Swap dims 1 and 2 so dim 1 indexes frames (its size is used as the mask length below).
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        # `_mask_hidden_states` is inherited (not shown here); presumably applies time/feature masking,
        # e.g. with `masked_spec_embed` — confirm against the parent class.
        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        # Tuple output: last hidden state followed by whatever extra entries the encoder returned.
        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
  233. class HubertForCTC(Wav2Vec2ForCTC):
  234. pass
  235. class HubertForSequenceClassification(Wav2Vec2ForSequenceClassification):
  236. pass
# Public API of this module.
__all__ = ["HubertForCTC", "HubertForSequenceClassification", "HubertModel", "HubertPreTrainedModel"]