processing_wav2vec2.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # coding=utf-8
  2. # Copyright 2021 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Speech processor class for Wav2Vec2
  17. """
  18. import warnings
  19. from contextlib import contextmanager
  20. from typing import Optional, Union
  21. from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
  22. from ...tokenization_utils_base import AudioInput, PreTokenizedInput, TextInput
  23. from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
  24. from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
  25. class Wav2Vec2ProcessorKwargs(ProcessingKwargs, total=False):
  26. _defaults = {}
  27. class Wav2Vec2Processor(ProcessorMixin):
  28. r"""
  29. Constructs a Wav2Vec2 processor which wraps a Wav2Vec2 feature extractor and a Wav2Vec2 CTC tokenizer into a single
  30. processor.
  31. [`Wav2Vec2Processor`] offers all the functionalities of [`Wav2Vec2FeatureExtractor`] and [`PreTrainedTokenizer`].
  32. See the docstring of [`~Wav2Vec2Processor.__call__`] and [`~Wav2Vec2Processor.decode`] for more information.
  33. Args:
  34. feature_extractor (`Wav2Vec2FeatureExtractor`):
  35. An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is a required input.
  36. tokenizer ([`PreTrainedTokenizer`]):
  37. An instance of [`PreTrainedTokenizer`]. The tokenizer is a required input.
  38. """
  39. feature_extractor_class = "Wav2Vec2FeatureExtractor"
  40. tokenizer_class = "AutoTokenizer"
  41. def __init__(self, feature_extractor, tokenizer):
  42. super().__init__(feature_extractor, tokenizer)
  43. self.current_processor = self.feature_extractor
  44. self._in_target_context_manager = False
  45. @classmethod
  46. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
  47. try:
  48. return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
  49. except (OSError, ValueError):
  50. warnings.warn(
  51. f"Loading a tokenizer inside {cls.__name__} from a config that does not"
  52. " include a `tokenizer_class` attribute is deprecated and will be "
  53. "removed in v5. Please add `'tokenizer_class': 'Wav2Vec2CTCTokenizer'`"
  54. " attribute to either your `config.json` or `tokenizer_config.json` "
  55. "file to suppress this warning: ",
  56. FutureWarning,
  57. )
  58. feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
  59. tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
  60. return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
  61. def __call__(
  62. self,
  63. audio: Optional[AudioInput] = None,
  64. text: Optional[Union[str, list[str], TextInput, PreTokenizedInput]] = None,
  65. images=None,
  66. videos=None,
  67. **kwargs: Unpack[Wav2Vec2ProcessorKwargs],
  68. ):
  69. """
  70. This method forwards all arguments to [`Wav2Vec2FeatureExtractor.__call__`] and/or
  71. [`PreTrainedTokenizer.__call__`] depending on the input modality and returns their outputs. If both modalities are passed, [`Wav2Vec2FeatureExtractor.__call__`] and [`PreTrainedTokenizer.__call__`] are called.
  72. Args:
  73. audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*):
  74. An audio input is passed to [`Wav2Vec2FeatureExtractor.__call__`].
  75. text (`str`, `List[str]`, *optional*):
  76. A text input is passed to [`PreTrainedTokenizer.__call__`].
  77. Returns:
  78. This method returns the results of each `call` method. If both are used, the output is a dictionary containing the results of both.
  79. """
  80. if "raw_speech" in kwargs:
  81. warnings.warn("Using `raw_speech` as a keyword argument is deprecated. Use `audio` instead.")
  82. audio = kwargs.pop("raw_speech")
  83. if audio is None and text is None:
  84. raise ValueError("You need to specify either an `audio` or `text` input to process.")
  85. output_kwargs = self._merge_kwargs(
  86. Wav2Vec2ProcessorKwargs,
  87. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  88. **kwargs,
  89. )
  90. # For backward compatibility
  91. if self._in_target_context_manager:
  92. return self.current_processor(
  93. audio,
  94. **output_kwargs["audio_kwargs"],
  95. **output_kwargs["text_kwargs"],
  96. **output_kwargs["common_kwargs"],
  97. )
  98. if audio is not None:
  99. inputs = self.feature_extractor(audio, **output_kwargs["audio_kwargs"])
  100. if text is not None:
  101. encodings = self.tokenizer(text, **output_kwargs["text_kwargs"])
  102. if text is None:
  103. return inputs
  104. elif audio is None:
  105. return encodings
  106. else:
  107. inputs["labels"] = encodings["input_ids"]
  108. return inputs
  109. def pad(self, *args, **kwargs):
  110. """
  111. This method operates on batches of extracted features and/or tokenized text. It forwards all arguments to
  112. [`Wav2Vec2FeatureExtractor.pad`] and/or [`PreTrainedTokenizer.pad`] depending on the input modality and returns their outputs. If both modalities are passed, [`Wav2Vec2FeatureExtractor.pad`] and [`PreTrainedTokenizer.pad`] are called.
  113. Args:
  114. input_features:
  115. When the first argument is a dictionary containing a batch of tensors, or the `input_features` argument is present, it is passed to [`Wav2Vec2FeatureExtractor.pad`].
  116. labels:
  117. When the `label` argument is present, it is passed to [`PreTrainedTokenizer.pad`].
  118. Returns:
  119. This method returns the results of each `pad` method. If both are used, the output is a dictionary containing the results of both.
  120. """
  121. # For backward compatibility
  122. if self._in_target_context_manager:
  123. return self.current_processor.pad(*args, **kwargs)
  124. input_features = kwargs.pop("input_features", None)
  125. labels = kwargs.pop("labels", None)
  126. if len(args) > 0:
  127. input_features = args[0]
  128. args = args[1:]
  129. if input_features is not None:
  130. input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
  131. if labels is not None:
  132. labels = self.tokenizer.pad(labels, **kwargs)
  133. if labels is None:
  134. return input_features
  135. elif input_features is None:
  136. return labels
  137. else:
  138. input_features["labels"] = labels["input_ids"]
  139. return input_features
  140. @property
  141. def model_input_names(self):
  142. # The processor doesn't return text ids and the model seems to not need them
  143. feature_extractor_input_names = self.feature_extractor.model_input_names
  144. return feature_extractor_input_names + ["labels"]
  145. @contextmanager
  146. def as_target_processor(self):
  147. """
  148. Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
  149. Wav2Vec2.
  150. """
  151. warnings.warn(
  152. "`as_target_processor` is deprecated and will be removed in v5 of Transformers. You can process your "
  153. "labels by using the argument `text` of the regular `__call__` method (either in the same call as "
  154. "your audio inputs, or in a separate call."
  155. )
  156. self._in_target_context_manager = True
  157. self.current_processor = self.tokenizer
  158. yield
  159. self.current_processor = self.feature_extractor
  160. self._in_target_context_manager = False
  161. __all__ = ["Wav2Vec2Processor"]