automatic_speech_recognition.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. # Copyright 2021 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections import defaultdict
  15. from typing import TYPE_CHECKING, Any, Optional, Union
  16. import numpy as np
  17. import requests
  18. from ..generation import GenerationConfig
  19. from ..tokenization_utils import PreTrainedTokenizer
  20. from ..utils import is_torch_available, is_torchaudio_available, is_torchcodec_available, logging
  21. from .audio_utils import ffmpeg_read
  22. from .base import ChunkPipeline
  23. if TYPE_CHECKING:
  24. from pyctcdecode import BeamSearchDecoderCTC
  25. from ..feature_extraction_sequence_utils import SequenceFeatureExtractor
  26. from ..modeling_utils import PreTrainedModel
  27. logger = logging.get_logger(__name__)
  28. if is_torch_available():
  29. import torch
  30. from ..models.auto.modeling_auto import MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
  31. def rescale_stride(stride, ratio):
  32. """
  33. Rescales the stride values from audio space to tokens/logits space.
  34. (160_000, 16_000, 16_000) -> (2000, 200, 200) for instance.
  35. """
  36. # Shape is [B, SEQ] for tokens
  37. # [B, SEQ, V] for logits
  38. new_strides = []
  39. for input_n, left, right in stride:
  40. token_n = int(round(input_n * ratio))
  41. left = int(round(left / input_n * token_n))
  42. right = int(round(right / input_n * token_n))
  43. new_stride = (token_n, left, right)
  44. new_strides.append(new_stride)
  45. return new_strides
  46. def chunk_iter(inputs, feature_extractor, chunk_len, stride_left, stride_right, dtype=None):
  47. inputs_len = inputs.shape[0]
  48. step = chunk_len - stride_left - stride_right
  49. for chunk_start_idx in range(0, inputs_len, step):
  50. chunk_end_idx = chunk_start_idx + chunk_len
  51. chunk = inputs[chunk_start_idx:chunk_end_idx]
  52. processed = feature_extractor(
  53. chunk,
  54. sampling_rate=feature_extractor.sampling_rate,
  55. return_tensors="pt",
  56. return_attention_mask=True,
  57. )
  58. if dtype is not None:
  59. processed = processed.to(dtype=dtype)
  60. _stride_left = 0 if chunk_start_idx == 0 else stride_left
  61. is_last = chunk_end_idx >= inputs_len
  62. _stride_right = 0 if is_last else stride_right
  63. chunk_len = chunk.shape[0]
  64. stride = (chunk_len, _stride_left, _stride_right)
  65. if chunk.shape[0] > _stride_left:
  66. yield {"is_last": is_last, "stride": stride, **processed}
  67. if is_last:
  68. break
  69. def _find_longest_common_sequence(sequences, tokenizer):
  70. # TODO Use a faster algorithm this can probably be done in O(n)
  71. # using suffix array.
  72. # It might be tedious to do because of fault tolerance.
  73. # We actually have a really good property which is that the total sequence
  74. # MUST be those subsequences in order.
  75. # Also the algorithm should be more tolerant to errors.
  76. sequence = [tok_id for tok_id in sequences[0][0].tolist() if tok_id not in tokenizer.all_special_ids]
  77. for new_seq in sequences[1:]:
  78. new_sequence = [tok_id for tok_id in new_seq[0].tolist() if tok_id not in tokenizer.all_special_ids]
  79. index = 0
  80. max_ = 0.0
  81. for i in range(1, len(new_sequence) + 1):
  82. # epsilon to favor long perfect matches
  83. eps = i / 10000.0
  84. matches = np.sum(np.array(sequence[-i:]) == np.array(new_sequence[:i]))
  85. matching = matches / i + eps
  86. if matches > 1 and matching > max_:
  87. index = i
  88. max_ = matching
  89. sequence.extend(new_sequence[index:])
  90. return np.array(sequence)
  91. class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
  92. """
  93. Pipeline that aims at extracting spoken text contained within some audio.
  94. The input can be either a raw waveform or a audio file. In case of the audio file, ffmpeg should be installed for
  95. to support multiple audio formats
  96. Unless the model you're using explicitly sets these generation parameters in its configuration files
  97. (`generation_config.json`), the following default values will be used:
  98. - max_new_tokens: 256
  99. - num_beams: 5
  100. Example:
  101. ```python
  102. >>> from transformers import pipeline
  103. >>> transcriber = pipeline(model="openai/whisper-base")
  104. >>> transcriber("https://huggingface.co/datasets/Narsil/asr_dummy/resolve/main/1.flac")
  105. {'text': ' He hoped there would be stew for dinner, turnips and carrots and bruised potatoes and fat mutton pieces to be ladled out in thick, peppered flour-fatten sauce.'}
  106. ```
  107. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  108. Arguments:
  109. model ([`PreTrainedModel`] or [`TFPreTrainedModel`]):
  110. The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
  111. [`PreTrainedModel`] for PyTorch and [`TFPreTrainedModel`] for TensorFlow.
  112. feature_extractor ([`SequenceFeatureExtractor`]):
  113. The feature extractor that will be used by the pipeline to encode waveform for the model.
  114. tokenizer ([`PreTrainedTokenizer`]):
  115. The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
  116. [`PreTrainedTokenizer`].
  117. decoder (`pyctcdecode.BeamSearchDecoderCTC`, *optional*):
  118. [PyCTCDecode's
  119. BeamSearchDecoderCTC](https://github.com/kensho-technologies/pyctcdecode/blob/2fd33dc37c4111417e08d89ccd23d28e9b308d19/pyctcdecode/decoder.py#L180)
  120. can be passed for language model boosted decoding. See [`Wav2Vec2ProcessorWithLM`] for more information.
  121. chunk_length_s (`float`, *optional*, defaults to 0):
  122. The input length for in each chunk. If `chunk_length_s = 0` then chunking is disabled (default).
  123. <Tip>
  124. For more information on how to effectively use `chunk_length_s`, please have a look at the [ASR chunking
  125. blog post](https://huggingface.co/blog/asr-chunking).
  126. </Tip>
  127. stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
  128. The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
  129. the model to *see* more context and infer letters better than without this context but the pipeline
  130. discards the stride bits at the end to make the final reconstitution as perfect as possible.
  131. <Tip>
  132. For more information on how to effectively use `stride_length_s`, please have a look at the [ASR chunking
  133. blog post](https://huggingface.co/blog/asr-chunking).
  134. </Tip>
  135. framework (`str`, *optional*):
  136. The framework to use, either `"pt"` for PyTorch or `"tf"` for TensorFlow. The specified framework must be
  137. installed. If no framework is specified, will default to the one currently installed. If no framework is
  138. specified and both frameworks are installed, will default to the framework of the `model`, or to PyTorch if
  139. no model is provided.
  140. device (Union[`int`, `torch.device`], *optional*):
  141. Device ordinal for CPU/GPU supports. Setting this to `None` will leverage CPU, a positive will run the
  142. model on the associated CUDA device id.
  143. """
  144. _pipeline_calls_generate = True
  145. _load_processor = False
  146. _load_image_processor = False
  147. _load_feature_extractor = True
  148. _load_tokenizer = True
  149. # Make sure the docstring is updated when the default generation config is changed
  150. _default_generation_config = GenerationConfig(
  151. max_new_tokens=256,
  152. num_beams=5, # follows openai's whisper implementation
  153. )
  154. def __init__(
  155. self,
  156. model: "PreTrainedModel",
  157. feature_extractor: Optional[Union["SequenceFeatureExtractor", str]] = None,
  158. tokenizer: Optional[PreTrainedTokenizer] = None,
  159. decoder: Optional[Union["BeamSearchDecoderCTC", str]] = None,
  160. device: Optional[Union[int, "torch.device"]] = None,
  161. **kwargs,
  162. ):
  163. # set the model type so we can check we have the right pre- and post-processing parameters
  164. if model.config.model_type == "whisper":
  165. self.type = "seq2seq_whisper"
  166. elif model.__class__.__name__ in MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES.values():
  167. self.type = "seq2seq"
  168. elif (
  169. feature_extractor._processor_class
  170. and feature_extractor._processor_class.endswith("WithLM")
  171. and decoder is not None
  172. ):
  173. self.decoder = decoder
  174. self.type = "ctc_with_lm"
  175. else:
  176. self.type = "ctc"
  177. super().__init__(model, tokenizer, feature_extractor, device=device, **kwargs)
  178. def __call__(self, inputs: Union[np.ndarray, bytes, str, dict], **kwargs: Any) -> list[dict[str, Any]]:
  179. """
  180. Transcribe the audio sequence(s) given as inputs to text. See the [`AutomaticSpeechRecognitionPipeline`]
  181. documentation for more information.
  182. Args:
  183. inputs (`np.ndarray` or `bytes` or `str` or `dict`):
  184. The inputs is either :
  185. - `str` that is either the filename of a local audio file, or a public URL address to download the
  186. audio file. The file will be read at the correct sampling rate to get the waveform using
  187. *ffmpeg*. This requires *ffmpeg* to be installed on the system.
  188. - `bytes` it is supposed to be the content of an audio file and is interpreted by *ffmpeg* in the
  189. same way.
  190. - (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
  191. Raw audio at the correct sampling rate (no further check will be done)
  192. - `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
  193. pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "raw":
  194. np.array}` with optionally a `"stride": (left: int, right: int)` than can ask the pipeline to
  195. treat the first `left` samples and last `right` samples to be ignored in decoding (but used at
  196. inference to provide more context to the model). Only use `stride` with CTC models.
  197. return_timestamps (*optional*, `str` or `bool`):
  198. Only available for pure CTC models (Wav2Vec2, HuBERT, etc) and the Whisper model. Not available for
  199. other sequence-to-sequence models.
  200. For CTC models, timestamps can take one of two formats:
  201. - `"char"`: the pipeline will return timestamps along the text for every character in the text. For
  202. instance, if you get `[{"text": "h", "timestamp": (0.5, 0.6)}, {"text": "i", "timestamp": (0.7,
  203. 0.9)}]`, then it means the model predicts that the letter "h" was spoken after `0.5` and before
  204. `0.6` seconds.
  205. - `"word"`: the pipeline will return timestamps along the text for every word in the text. For
  206. instance, if you get `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text": "there", "timestamp":
  207. (1.0, 1.5)}]`, then it means the model predicts that the word "hi" was spoken after `0.5` and
  208. before `0.9` seconds.
  209. For the Whisper model, timestamps can take one of two formats:
  210. - `"word"`: same as above for word-level CTC timestamps. Word-level timestamps are predicted
  211. through the *dynamic-time warping (DTW)* algorithm, an approximation to word-level timestamps
  212. by inspecting the cross-attention weights.
  213. - `True`: the pipeline will return timestamps along the text for *segments* of words in the text.
  214. For instance, if you get `[{"text": " Hi there!", "timestamp": (0.5, 1.5)}]`, then it means the
  215. model predicts that the segment "Hi there!" was spoken after `0.5` and before `1.5` seconds.
  216. Note that a segment of text refers to a sequence of one or more words, rather than individual
  217. words as with word-level timestamps.
  218. generate_kwargs (`dict`, *optional*):
  219. The dictionary of ad-hoc parametrization of `generate_config` to be used for the generation call. For a
  220. complete overview of generate, check the [following
  221. guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
  222. Return:
  223. `Dict`: A dictionary with the following keys:
  224. - **text** (`str`): The recognized text.
  225. - **chunks** (*optional(, `list[Dict]`)
  226. When using `return_timestamps`, the `chunks` will become a list containing all the various text
  227. chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamp": (0.5, 0.9)}, {"text":
  228. "there", "timestamp": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
  229. `"".join(chunk["text"] for chunk in output["chunks"])`.
  230. """
  231. return super().__call__(inputs, **kwargs)
  232. def _sanitize_parameters(
  233. self,
  234. chunk_length_s=None,
  235. stride_length_s=None,
  236. ignore_warning=None,
  237. decoder_kwargs=None,
  238. return_timestamps=None,
  239. return_language=None,
  240. **generate_kwargs,
  241. ):
  242. preprocess_params = {}
  243. forward_params = {}
  244. postprocess_params = {}
  245. # Preprocess params
  246. if chunk_length_s is not None:
  247. if self.type in ["seq2seq", "seq2seq_whisper"] and not ignore_warning:
  248. type_warning = (
  249. "Using `chunk_length_s` is very experimental with seq2seq models. The results will not necessarily"
  250. " be entirely accurate and will have caveats. More information:"
  251. " https://github.com/huggingface/transformers/pull/20104. Ignore this warning with pipeline(...,"
  252. " ignore_warning=True)."
  253. )
  254. if self.type == "seq2seq_whisper":
  255. type_warning += (
  256. " To use Whisper for long-form transcription, use rather the model's `generate` method directly "
  257. "as the model relies on it's own chunking mechanism (cf. Whisper original paper, section 3.8. "
  258. "Long-form Transcription)."
  259. )
  260. logger.warning(type_warning)
  261. preprocess_params["chunk_length_s"] = chunk_length_s
  262. if stride_length_s is not None:
  263. preprocess_params["stride_length_s"] = stride_length_s
  264. # Forward params
  265. # BC: accept a dictionary of generation kwargs (as opposed to **generate_kwargs)
  266. if "generate_kwargs" in generate_kwargs:
  267. forward_params.update(generate_kwargs.pop("generate_kwargs"))
  268. # Default use for kwargs: they are generation-time kwargs
  269. forward_params.update(generate_kwargs)
  270. if getattr(self, "assistant_model", None) is not None:
  271. forward_params["assistant_model"] = self.assistant_model
  272. if getattr(self, "assistant_tokenizer", None) is not None:
  273. forward_params["tokenizer"] = self.tokenizer
  274. forward_params["assistant_tokenizer"] = self.assistant_tokenizer
  275. # Postprocess params
  276. if decoder_kwargs is not None:
  277. postprocess_params["decoder_kwargs"] = decoder_kwargs
  278. if return_language is not None:
  279. if self.type != "seq2seq_whisper":
  280. raise ValueError("Only Whisper can return language for now.")
  281. postprocess_params["return_language"] = return_language
  282. # Parameter used in more than one place
  283. # in some models like whisper, the generation config has a `return_timestamps` key
  284. if hasattr(self, "generation_config") and hasattr(self.generation_config, "return_timestamps"):
  285. return_timestamps = return_timestamps or self.generation_config.return_timestamps
  286. if return_timestamps is not None:
  287. # Check whether we have a valid setting for return_timestamps and throw an error before we perform a forward pass
  288. if self.type == "seq2seq" and return_timestamps:
  289. raise ValueError("We cannot return_timestamps yet on non-CTC models apart from Whisper!")
  290. if self.type == "ctc_with_lm" and return_timestamps != "word":
  291. raise ValueError("CTC with LM can only predict word level timestamps, set `return_timestamps='word'`")
  292. if self.type == "ctc" and return_timestamps not in ["char", "word"]:
  293. raise ValueError(
  294. "CTC can either predict character level timestamps, or word level timestamps. "
  295. "Set `return_timestamps='char'` or `return_timestamps='word'` as required."
  296. )
  297. if self.type == "seq2seq_whisper" and return_timestamps == "char":
  298. raise ValueError(
  299. "Whisper cannot return `char` timestamps, only word level or segment level timestamps. "
  300. "Use `return_timestamps='word'` or `return_timestamps=True` respectively."
  301. )
  302. forward_params["return_timestamps"] = return_timestamps
  303. postprocess_params["return_timestamps"] = return_timestamps
  304. return preprocess_params, forward_params, postprocess_params
  305. def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
  306. if isinstance(inputs, str):
  307. if inputs.startswith("http://") or inputs.startswith("https://"):
  308. # We need to actually check for a real protocol, otherwise it's impossible to use a local file
  309. # like http_huggingface_co.png
  310. inputs = requests.get(inputs).content
  311. else:
  312. with open(inputs, "rb") as f:
  313. inputs = f.read()
  314. if isinstance(inputs, bytes):
  315. inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
  316. stride = None
  317. extra = {}
  318. if is_torch_available():
  319. import torch
  320. if isinstance(inputs, torch.Tensor):
  321. inputs = inputs.cpu().numpy()
  322. if is_torchcodec_available():
  323. import torchcodec
  324. if isinstance(inputs, torchcodec.decoders.AudioDecoder):
  325. _audio_samples = inputs.get_all_samples()
  326. # torchcodec always returns (num_channels, num_samples)
  327. # while before (datasets < 4.0) we had (2, num_samples) if stereo, (num_samples,) if mono
  328. _array = _audio_samples.data
  329. _array = _array[0] if _array.ndim == 2 and _array.shape[0] == 1 else _array
  330. inputs = {"array": _array, "sampling_rate": _audio_samples.sample_rate}
  331. if isinstance(inputs, dict):
  332. stride = inputs.pop("stride", None)
  333. # Accepting `"array"` which is the key defined in `datasets` for
  334. # better integration
  335. if not ("sampling_rate" in inputs and ("raw" in inputs or "array" in inputs)):
  336. raise ValueError(
  337. "When passing a dictionary to AutomaticSpeechRecognitionPipeline, the dict needs to contain a "
  338. '"raw" key containing the numpy array or torch tensor representing the audio and a "sampling_rate" key, '
  339. "containing the sampling_rate associated with that array"
  340. )
  341. _inputs = inputs.pop("raw", None)
  342. if _inputs is None:
  343. # Remove path which will not be used from `datasets`.
  344. inputs.pop("path", None)
  345. _inputs = inputs.pop("array", None)
  346. in_sampling_rate = inputs.pop("sampling_rate")
  347. extra = inputs
  348. inputs = _inputs
  349. if in_sampling_rate != self.feature_extractor.sampling_rate:
  350. if is_torchaudio_available():
  351. from torchaudio import functional as F
  352. else:
  353. raise ImportError(
  354. "torchaudio is required to resample audio samples in AutomaticSpeechRecognitionPipeline. "
  355. "The torchaudio package can be installed through: `pip install torchaudio`."
  356. )
  357. inputs = F.resample(
  358. torch.from_numpy(inputs) if isinstance(inputs, np.ndarray) else inputs,
  359. in_sampling_rate,
  360. in_sampling_rate,
  361. self.feature_extractor.sampling_rate,
  362. ).numpy()
  363. ratio = self.feature_extractor.sampling_rate / in_sampling_rate
  364. else:
  365. ratio = 1
  366. if stride is not None:
  367. if stride[0] + stride[1] > inputs.shape[0]:
  368. raise ValueError("Stride is too large for input")
  369. # Stride needs to get the chunk length here, it's going to get
  370. # swallowed by the `feature_extractor` later, and then batching
  371. # can add extra data in the inputs, so we need to keep track
  372. # of the original length in the stride so we can cut properly.
  373. stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
  374. if not isinstance(inputs, (np.ndarray, torch.Tensor)):
  375. raise TypeError(f"We expect a numpy ndarray or torch tensor as input, got `{type(inputs)}`")
  376. if inputs.ndim != 1:
  377. logger.warning(
  378. f"We expect a single channel audio input for AutomaticSpeechRecognitionPipeline, got {inputs.ndim}. Taking the mean of the channels for mono conversion."
  379. )
  380. inputs = inputs.mean(axis=0)
  381. if chunk_length_s:
  382. if stride_length_s is None:
  383. stride_length_s = chunk_length_s / 6
  384. if isinstance(stride_length_s, (int, float)):
  385. stride_length_s = [stride_length_s, stride_length_s]
  386. # XXX: Carefully, this variable will not exist in `seq2seq` setting.
  387. # Currently chunking is not possible at this level for `seq2seq` so
  388. # it's ok.
  389. align_to = getattr(self.model.config, "inputs_to_logits_ratio", 1)
  390. chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
  391. stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
  392. stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
  393. if chunk_len < stride_left + stride_right:
  394. raise ValueError("Chunk length must be superior to stride length")
  395. for item in chunk_iter(inputs, self.feature_extractor, chunk_len, stride_left, stride_right, self.dtype):
  396. yield {**item, **extra}
  397. else:
  398. if self.type == "seq2seq_whisper" and inputs.shape[0] > self.feature_extractor.n_samples:
  399. processed = self.feature_extractor(
  400. inputs,
  401. sampling_rate=self.feature_extractor.sampling_rate,
  402. truncation=False,
  403. padding="longest",
  404. return_tensors="pt",
  405. return_attention_mask=True,
  406. )
  407. else:
  408. if self.type == "seq2seq_whisper" and stride is None:
  409. processed = self.feature_extractor(
  410. inputs,
  411. sampling_rate=self.feature_extractor.sampling_rate,
  412. return_tensors="pt",
  413. return_token_timestamps=True,
  414. return_attention_mask=True,
  415. )
  416. extra["num_frames"] = processed.pop("num_frames")
  417. else:
  418. processed = self.feature_extractor(
  419. inputs,
  420. sampling_rate=self.feature_extractor.sampling_rate,
  421. return_tensors="pt",
  422. return_attention_mask=True,
  423. )
  424. if self.dtype is not None:
  425. processed = processed.to(dtype=self.dtype)
  426. if stride is not None:
  427. if self.type == "seq2seq":
  428. raise ValueError("Stride is only usable with CTC models, try removing it !")
  429. processed["stride"] = stride
  430. yield {"is_last": True, **processed, **extra}
  431. def _forward(self, model_inputs, return_timestamps=False, **generate_kwargs):
  432. attention_mask = model_inputs.pop("attention_mask", None)
  433. stride = model_inputs.pop("stride", None)
  434. num_frames = model_inputs.pop("num_frames", None)
  435. is_last = model_inputs.pop("is_last")
  436. if stride is not None and num_frames is not None:
  437. raise ValueError("num_frames must be used only when stride is None")
  438. if self.type in {"seq2seq", "seq2seq_whisper"}:
  439. # Consume values so we can let extra information flow freely through
  440. # the pipeline (important for `partial` in microphone)
  441. if "input_features" in model_inputs:
  442. inputs = model_inputs.pop("input_features")
  443. elif "input_values" in model_inputs:
  444. inputs = model_inputs.pop("input_values")
  445. else:
  446. raise ValueError(
  447. "Seq2Seq speech recognition model requires either a "
  448. f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
  449. )
  450. # custom processing for Whisper timestamps and word-level timestamps
  451. return_timestamps = return_timestamps or getattr(self.generation_config, "return_timestamps", False)
  452. if return_timestamps and self.type == "seq2seq_whisper":
  453. generate_kwargs["return_timestamps"] = bool(return_timestamps)
  454. if return_timestamps == "word":
  455. generate_kwargs["return_token_timestamps"] = True
  456. generate_kwargs["return_segments"] = True
  457. # User-defined `generation_config` passed to the pipeline call take precedence
  458. if "generation_config" not in generate_kwargs:
  459. generate_kwargs["generation_config"] = self.generation_config
  460. main_input_name = self.model.main_input_name if hasattr(self.model, "main_input_name") else "inputs"
  461. generate_kwargs = {
  462. main_input_name: inputs,
  463. "attention_mask": attention_mask,
  464. **generate_kwargs,
  465. }
  466. tokens = self.model.generate(**generate_kwargs)
  467. # whisper longform generation stores timestamps in "segments"
  468. if return_timestamps == "word" and self.type == "seq2seq_whisper":
  469. if "segments" not in tokens:
  470. out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
  471. else:
  472. token_timestamps = [
  473. torch.cat([segment["token_timestamps"] for segment in segment_list])
  474. for segment_list in tokens["segments"]
  475. ]
  476. out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
  477. else:
  478. out = {"tokens": tokens}
  479. if self.type == "seq2seq_whisper":
  480. if stride is not None:
  481. out["stride"] = stride
  482. else:
  483. inputs = {
  484. self.model.main_input_name: model_inputs.pop(self.model.main_input_name),
  485. "attention_mask": attention_mask,
  486. }
  487. outputs = self.model(**inputs)
  488. logits = outputs.logits
  489. if self.type == "ctc_with_lm":
  490. out = {"logits": logits}
  491. else:
  492. out = {"tokens": logits.argmax(dim=-1)}
  493. if stride is not None:
  494. # Send stride to `postprocess`.
  495. # it needs to be handled there where
  496. # the pieces are to be concatenated.
  497. ratio = 1 / self.model.config.inputs_to_logits_ratio
  498. if isinstance(stride, tuple):
  499. out["stride"] = rescale_stride([stride], ratio)[0]
  500. else:
  501. out["stride"] = rescale_stride(stride, ratio)
  502. # Leftover
  503. extra = model_inputs
  504. return {"is_last": is_last, **out, **extra}
  505. def postprocess(
  506. self, model_outputs, decoder_kwargs: Optional[dict] = None, return_timestamps=None, return_language=None
  507. ):
  508. # Optional return types
  509. optional = {}
  510. final_items = []
  511. key = "logits" if self.type == "ctc_with_lm" else "tokens"
  512. stride = None
  513. for outputs in model_outputs:
  514. if self.framework == "pt" and outputs[key].dtype in (torch.bfloat16, torch.float16):
  515. items = outputs[key].to(torch.float32).numpy()
  516. else:
  517. items = outputs[key].numpy()
  518. stride = outputs.get("stride", None)
  519. if stride is not None and self.type in {"ctc", "ctc_with_lm"}:
  520. total_n, left, right = stride
  521. # Total_n might be < logits.shape[1]
  522. # because of padding, that's why
  523. # we need to reconstruct this information
  524. # This won't work with left padding (which doesn't exist right now)
  525. right_n = total_n - right
  526. items = items[:, left:right_n]
  527. final_items.append(items)
  528. if stride and self.type == "seq2seq":
  529. items = _find_longest_common_sequence(final_items, self.tokenizer)
  530. elif self.type == "seq2seq_whisper":
  531. time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
  532. # Send the chunking back to seconds, it's easier to handle in whisper
  533. sampling_rate = self.feature_extractor.sampling_rate
  534. for output in model_outputs:
  535. if "stride" in output:
  536. chunk_len, stride_left, stride_right = output["stride"]
  537. # Go back in seconds
  538. chunk_len /= sampling_rate
  539. stride_left /= sampling_rate
  540. stride_right /= sampling_rate
  541. output["stride"] = chunk_len, stride_left, stride_right
  542. text, optional = self.tokenizer._decode_asr(
  543. model_outputs,
  544. return_timestamps=return_timestamps,
  545. return_language=return_language,
  546. time_precision=time_precision,
  547. )
  548. else:
  549. items = np.concatenate(final_items, axis=1)
  550. items = items.squeeze(0)
  551. if self.type == "ctc_with_lm":
  552. if decoder_kwargs is None:
  553. decoder_kwargs = {}
  554. beams = self.decoder.decode_beams(items, **decoder_kwargs)
  555. text = beams[0][0]
  556. if return_timestamps:
  557. # Simply cast from pyctcdecode format to wav2vec2 format to leverage
  558. # pre-existing code later
  559. chunk_offset = beams[0][2]
  560. offsets = []
  561. for word, (start_offset, end_offset) in chunk_offset:
  562. offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
  563. elif self.type != "seq2seq_whisper":
  564. skip_special_tokens = self.type != "ctc"
  565. text = self.tokenizer.decode(items, skip_special_tokens=skip_special_tokens)
  566. if return_timestamps:
  567. offsets = self.tokenizer.decode(
  568. items, skip_special_tokens=skip_special_tokens, output_char_offsets=True
  569. )["char_offsets"]
  570. if return_timestamps == "word":
  571. offsets = self.tokenizer._get_word_offsets(offsets, self.tokenizer.replace_word_delimiter_char)
  572. if return_timestamps and self.type not in {"seq2seq", "seq2seq_whisper"}:
  573. chunks = []
  574. for item in offsets:
  575. start = item["start_offset"] * self.model.config.inputs_to_logits_ratio
  576. start /= self.feature_extractor.sampling_rate
  577. stop = item["end_offset"] * self.model.config.inputs_to_logits_ratio
  578. stop /= self.feature_extractor.sampling_rate
  579. chunks.append({"text": item[return_timestamps], "timestamp": (start, stop)})
  580. optional["chunks"] = chunks
  581. extra = defaultdict(list)
  582. for output in model_outputs:
  583. output.pop("tokens", None)
  584. output.pop("logits", None)
  585. output.pop("is_last", None)
  586. output.pop("stride", None)
  587. output.pop("token_timestamps", None)
  588. for k, v in output.items():
  589. extra[k].append(v)
  590. return {"text": text, **optional, **extra}