tokenization_wav2vec2.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924
  1. # coding=utf-8
  2. # Copyright 2021 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Tokenization class for Wav2Vec2."""
  16. import json
  17. import os
  18. import warnings
  19. from dataclasses import dataclass
  20. from itertools import groupby
  21. from typing import TYPE_CHECKING, Optional, Union
  22. import numpy as np
  23. from ...tokenization_utils import PreTrainedTokenizer
  24. from ...tokenization_utils_base import AddedToken, BatchEncoding
  25. from ...utils import (
  26. ModelOutput,
  27. PaddingStrategy,
  28. TensorType,
  29. add_end_docstrings,
  30. is_flax_available,
  31. is_tf_available,
  32. is_torch_available,
  33. logging,
  34. to_py_obj,
  35. )
  36. logger = logging.get_logger(__name__)
  37. if TYPE_CHECKING:
  38. if is_torch_available():
  39. import torch
  40. if is_tf_available():
  41. import tensorflow as tf
  42. if is_flax_available():
  43. import jax.numpy as jnp # noqa: F401
  44. VOCAB_FILES_NAMES = {
  45. "vocab_file": "vocab.json",
  46. "tokenizer_config_file": "tokenizer_config.json",
  47. }
  48. # Wav2Vec2 has no max input length
  49. WAV2VEC2_KWARGS_DOCSTRING = r"""
  50. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
  51. Activates and controls padding. Accepts the following values:
  52. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  53. sequence if provided).
  54. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  55. acceptable input length for the model if that argument is not provided.
  56. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  57. lengths).
  58. max_length (`int`, *optional*):
  59. Controls the maximum length to use by one of the truncation/padding parameters.
  60. If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
  61. is required by one of the truncation/padding parameters. If the model has no specific maximum input
  62. length (like XLNet) truncation/padding to a maximum length will be deactivated.
  63. pad_to_multiple_of (`int`, *optional*):
  64. If set will pad the sequence to a multiple of the provided value. This is especially useful to enable
  65. the use of Tensor Cores on NVIDIA hardware with compute capability `>= 7.5` (Volta).
  66. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  67. If set, will return tensors instead of list of python integers. Acceptable values are:
  68. - `'tf'`: Return TensorFlow `tf.constant` objects.
  69. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  70. - `'np'`: Return Numpy `np.ndarray` objects.
  71. verbose (`bool`, *optional*, defaults to `True`):
  72. Whether or not to print more information and warnings.
  73. """
  74. ListOfDict = list[dict[str, Union[int, str]]]
  75. @dataclass
  76. class Wav2Vec2CTCTokenizerOutput(ModelOutput):
  77. """
  78. Output type of [` Wav2Vec2CTCTokenizer`], with transcription.
  79. Args:
  80. text (list of `str` or `str`):
  81. Decoded logits in text from. Usually the speech transcription.
  82. char_offsets (list of `list[dict[str, Union[int, str]]]` or `list[dict[str, Union[int, str]]]`):
  83. Offsets of the decoded characters. In combination with sampling rate and model downsampling rate char
  84. offsets can be used to compute time stamps for each character. Total logit score of the beam associated with
  85. produced text.
  86. word_offsets (list of `list[dict[str, Union[int, str]]]` or `list[dict[str, Union[int, str]]]`):
  87. Offsets of the decoded words. In combination with sampling rate and model downsampling rate word offsets
  88. can be used to compute time stamps for each word.
  89. """
  90. text: Union[list[str], str]
  91. char_offsets: Union[list[ListOfDict], ListOfDict] = None
  92. word_offsets: Union[list[ListOfDict], ListOfDict] = None
  93. class Wav2Vec2CTCTokenizer(PreTrainedTokenizer):
  94. """
  95. Constructs a Wav2Vec2CTC tokenizer.
  96. This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
  97. the superclass for more information regarding such methods.
  98. Args:
  99. vocab_file (`str`):
  100. File containing the vocabulary.
  101. bos_token (`str`, *optional*, defaults to `"<s>"`):
  102. The beginning of sentence token.
  103. eos_token (`str`, *optional*, defaults to `"</s>"`):
  104. The end of sentence token.
  105. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  106. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  107. token instead.
  108. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  109. The token used for padding, for example when batching sequences of different lengths.
  110. word_delimiter_token (`str`, *optional*, defaults to `"|"`):
  111. The token used for defining the end of a word.
  112. do_lower_case (`bool`, *optional*, defaults to `False`):
  113. Whether or not to accept lowercase input and lowercase the output when decoding.
  114. target_lang (`str`, *optional*):
  115. A target language the tokenizer should set by default. `target_lang` has to be defined for multi-lingual,
  116. nested vocabulary such as [facebook/mms-1b-all](https://huggingface.co/facebook/mms-1b-all).
  117. **kwargs
  118. Additional keyword arguments passed along to [`PreTrainedTokenizer`]
  119. """
  120. vocab_files_names = VOCAB_FILES_NAMES
  121. model_input_names = ["input_ids", "attention_mask"]
  122. def __init__(
  123. self,
  124. vocab_file,
  125. bos_token="<s>",
  126. eos_token="</s>",
  127. unk_token="<unk>",
  128. pad_token="<pad>",
  129. word_delimiter_token="|",
  130. replace_word_delimiter_char=" ",
  131. do_lower_case=False,
  132. target_lang=None,
  133. **kwargs,
  134. ):
  135. self._word_delimiter_token = word_delimiter_token
  136. self.do_lower_case = do_lower_case
  137. self.replace_word_delimiter_char = replace_word_delimiter_char
  138. self.target_lang = target_lang
  139. with open(vocab_file, encoding="utf-8") as vocab_handle:
  140. self.vocab = json.load(vocab_handle)
  141. # if target lang is defined vocab must be a nested dict
  142. # with each target lang being one vocabulary
  143. if target_lang is not None:
  144. self.encoder = self.vocab[target_lang]
  145. else:
  146. self.encoder = self.vocab
  147. self.decoder = {v: k for k, v in self.encoder.items()}
  148. super().__init__(
  149. unk_token=unk_token,
  150. bos_token=bos_token,
  151. eos_token=eos_token,
  152. pad_token=pad_token,
  153. do_lower_case=do_lower_case,
  154. word_delimiter_token=word_delimiter_token,
  155. replace_word_delimiter_char=replace_word_delimiter_char,
  156. target_lang=target_lang,
  157. **kwargs,
  158. )
  159. # make sure that tokens made of several
  160. # characters are not split at tokenization
  161. for token in self.encoder:
  162. if len(token) > 1:
  163. self.add_tokens(AddedToken(token, rstrip=True, lstrip=True, normalized=False))
  164. def set_target_lang(self, target_lang: str):
  165. """
  166. Set the target language of a nested multi-lingual dictionary
  167. """
  168. if self.vocab == self.encoder:
  169. raise ValueError(f"{self.vocab} is not a multi-lingual, nested tokenizer. Cannot set target language.")
  170. if target_lang not in self.vocab:
  171. raise ValueError(f"{target_lang} does not exist. Choose one of {', '.join(self.vocab.keys())}.")
  172. self.target_lang = target_lang
  173. self.init_kwargs["target_lang"] = target_lang
  174. self.encoder = self.vocab[target_lang]
  175. self.decoder = {v: k for k, v in self.encoder.items()}
  176. # make sure that tokens made of several
  177. # characters are not split at tokenization
  178. for token in self.encoder:
  179. if len(token) > 1:
  180. self.add_tokens(AddedToken(token, rstrip=True, lstrip=True, normalized=False))
  181. @property
  182. def word_delimiter_token(self) -> str:
  183. """
  184. `str`: Word delimiter token. Log an error if used while not having been set.
  185. """
  186. if self._word_delimiter_token is None and self.verbose:
  187. logger.error("Using word_delimiter_token, but it is not set yet.")
  188. return None
  189. return str(self._word_delimiter_token)
  190. @property
  191. def word_delimiter_token_id(self) -> Optional[int]:
  192. """
  193. `Optional[int]`: Id of the word_delimiter_token in the vocabulary. Returns `None` if the token has not been
  194. set.
  195. """
  196. if self._word_delimiter_token is None:
  197. return None
  198. return self.convert_tokens_to_ids(self.word_delimiter_token)
  199. @word_delimiter_token.setter
  200. def word_delimiter_token(self, value):
  201. self._word_delimiter_token = value
  202. @word_delimiter_token_id.setter
  203. def word_delimiter_token_id(self, value):
  204. self._word_delimiter_token = self.convert_tokens_to_ids(value)
  205. @property
  206. def vocab_size(self) -> int:
  207. return len(self.decoder)
  208. def get_vocab(self) -> dict:
  209. vocab = dict(self.encoder)
  210. vocab.update(self.added_tokens_encoder)
  211. return vocab
  212. def _add_tokens(self, new_tokens: Union[list[str], list[AddedToken]], special_tokens: bool = False) -> int:
  213. # Overwritten to never strip!
  214. to_add = []
  215. for token in new_tokens:
  216. if isinstance(token, str):
  217. to_add.append(AddedToken(token, rstrip=False, lstrip=False, normalized=False))
  218. else:
  219. to_add.append(token)
  220. return super()._add_tokens(to_add, special_tokens)
  221. def _tokenize(self, text, **kwargs):
  222. """
  223. Converts a string into a sequence of tokens (string), using the tokenizer.
  224. """
  225. if self.do_lower_case:
  226. text = text.upper()
  227. return list(text.replace(" ", self.word_delimiter_token))
  228. def _convert_token_to_id(self, token: str) -> int:
  229. """Converts a token (str) in an index (integer) using the vocab."""
  230. return self.encoder.get(token, self.encoder.get(self.unk_token))
  231. def _convert_id_to_token(self, index: int) -> str:
  232. """Converts an index (integer) in a token (str) using the vocab."""
  233. result = self.decoder.get(index, self.unk_token)
  234. return result
  235. def convert_tokens_to_string(
  236. self,
  237. tokens: list[str],
  238. group_tokens: bool = True,
  239. spaces_between_special_tokens: bool = False,
  240. output_char_offsets: bool = False,
  241. output_word_offsets: bool = False,
  242. ) -> dict[str, Union[str, float]]:
  243. """
  244. Converts a connectionist-temporal-classification (CTC) output tokens into a single string.
  245. """
  246. if len(tokens) == 0:
  247. return {"text": "", "char_offsets": [], "word_offsets": []}
  248. # group same tokens into non-repeating tokens in CTC style decoding
  249. if group_tokens:
  250. chars, char_repetitions = zip(*((token, len(list(group_iter))) for token, group_iter in groupby(tokens)))
  251. else:
  252. chars = tokens
  253. char_repetitions = len(tokens) * [1]
  254. # filter self.pad_token which is used as CTC-blank token
  255. processed_chars = list(filter(lambda char: char != self.pad_token, chars))
  256. # replace delimiter token
  257. processed_chars = [
  258. self.replace_word_delimiter_char if char == self.word_delimiter_token else char for char in processed_chars
  259. ]
  260. # retrieve offsets
  261. char_offsets = word_offsets = None
  262. if output_char_offsets or output_word_offsets:
  263. char_offsets = self._compute_offsets(char_repetitions, chars, self.pad_token)
  264. if len(char_offsets) != len(processed_chars):
  265. raise ValueError(
  266. f"`char_offsets`: {char_offsets} and `processed_tokens`: {processed_chars}"
  267. " have to be of the same length, but are: "
  268. f"`len(offsets)`: {len(char_offsets)} and `len(processed_tokens)`:"
  269. f" {len(processed_chars)}"
  270. )
  271. # set tokens to correct processed token
  272. for i, char in enumerate(processed_chars):
  273. char_offsets[i]["char"] = char
  274. # retrieve word offsets from character offsets
  275. word_offsets = None
  276. if output_word_offsets:
  277. word_offsets = self._get_word_offsets(char_offsets, self.replace_word_delimiter_char)
  278. # don't output chars if not set to True
  279. if not output_char_offsets:
  280. char_offsets = None
  281. # join to string
  282. join_char = " " if spaces_between_special_tokens else ""
  283. string = join_char.join(processed_chars).strip()
  284. if self.do_lower_case:
  285. string = string.lower()
  286. return {"text": string, "char_offsets": char_offsets, "word_offsets": word_offsets}
  287. @staticmethod
  288. def _compute_offsets(
  289. char_repetitions: list[int], chars: list[str], ctc_token: int
  290. ) -> list[dict[str, Union[str, int]]]:
  291. end_indices = np.asarray(char_repetitions).cumsum()
  292. start_indices = np.concatenate(([0], end_indices[:-1]))
  293. offsets = [
  294. {"char": t, "start_offset": s, "end_offset": e} for t, s, e in zip(chars, start_indices, end_indices)
  295. ]
  296. # filter out CTC token
  297. offsets = list(filter(lambda offsets: offsets["char"] != ctc_token, offsets))
  298. return offsets
  299. @staticmethod
  300. def _get_word_offsets(
  301. offsets: dict[str, Union[str, float]], word_delimiter_char: str = " "
  302. ) -> dict[str, Union[str, float]]:
  303. word_offsets = []
  304. last_state = "SPACE"
  305. word = ""
  306. start_offset = 0
  307. end_offset = 0
  308. for i, offset in enumerate(offsets):
  309. char = offset["char"]
  310. state = "SPACE" if char == word_delimiter_char else "WORD"
  311. if state == last_state:
  312. # If we are in the same state as before, we simply repeat what we've done before
  313. end_offset = offset["end_offset"]
  314. word += char
  315. else:
  316. # Switching state
  317. if state == "SPACE":
  318. # Finishing a word
  319. word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
  320. else:
  321. # Starting a new word
  322. start_offset = offset["start_offset"]
  323. end_offset = offset["end_offset"]
  324. word = char
  325. last_state = state
  326. if last_state == "WORD":
  327. word_offsets.append({"word": word, "start_offset": start_offset, "end_offset": end_offset})
  328. return word_offsets
  329. def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
  330. if is_split_into_words:
  331. text = " " + text
  332. return (text, kwargs)
  333. def _decode(
  334. self,
  335. token_ids: list[int],
  336. skip_special_tokens: bool = False,
  337. clean_up_tokenization_spaces: Optional[bool] = None,
  338. group_tokens: bool = True,
  339. spaces_between_special_tokens: bool = False,
  340. output_word_offsets: Optional[bool] = False,
  341. output_char_offsets: Optional[bool] = False,
  342. ) -> str:
  343. """
  344. special _decode function is needed for Wav2Vec2Tokenizer because added tokens should be treated exactly the
  345. same as tokens of the base vocabulary and therefore the function `convert_tokens_to_string` has to be called on
  346. the whole token list and not individually on added tokens
  347. """
  348. filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
  349. result = []
  350. for token in filtered_tokens:
  351. if skip_special_tokens and (
  352. token in self.all_special_ids or (token != self.pad_token and token in self.all_special_tokens)
  353. ):
  354. continue
  355. result.append(token)
  356. string_output = self.convert_tokens_to_string(
  357. result,
  358. group_tokens=group_tokens,
  359. spaces_between_special_tokens=spaces_between_special_tokens,
  360. output_word_offsets=output_word_offsets,
  361. output_char_offsets=output_char_offsets,
  362. )
  363. text = string_output["text"]
  364. clean_up_tokenization_spaces = (
  365. clean_up_tokenization_spaces
  366. if clean_up_tokenization_spaces is not None
  367. else self.clean_up_tokenization_spaces
  368. )
  369. if clean_up_tokenization_spaces:
  370. text = self.clean_up_tokenization(text)
  371. if output_word_offsets or output_char_offsets:
  372. return Wav2Vec2CTCTokenizerOutput(
  373. text=text,
  374. char_offsets=string_output["char_offsets"],
  375. word_offsets=string_output["word_offsets"],
  376. )
  377. else:
  378. return text
  379. # overwritten from `tokenization_utils_base.py` because tokenizer can output
  380. # `ModelOutput` which should not be a list for batched output and
  381. # because we need docs for `output_char_offsets` here
  382. def batch_decode(
  383. self,
  384. sequences: Union[list[int], list[list[int]], "np.ndarray", "torch.Tensor", "tf.Tensor"],
  385. skip_special_tokens: bool = False,
  386. clean_up_tokenization_spaces: Optional[bool] = None,
  387. output_char_offsets: bool = False,
  388. output_word_offsets: bool = False,
  389. **kwargs,
  390. ) -> list[str]:
  391. """
  392. Convert a list of lists of token ids into a list of strings by calling decode.
  393. Args:
  394. sequences (`Union[list[int], list[list[int]], np.ndarray, torch.Tensor, tf.Tensor]`):
  395. List of tokenized input ids. Can be obtained using the `__call__` method.
  396. skip_special_tokens (`bool`, *optional*, defaults to `False`):
  397. Whether or not to remove special tokens in the decoding.
  398. clean_up_tokenization_spaces (`bool`, *optional*):
  399. Whether or not to clean up the tokenization spaces.
  400. output_char_offsets (`bool`, *optional*, defaults to `False`):
  401. Whether or not to output character offsets. Character offsets can be used in combination with the
  402. sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.
  403. <Tip>
  404. Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
  405. use of `output_char_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
  406. output.
  407. </Tip>
  408. output_word_offsets (`bool`, *optional*, defaults to `False`):
  409. Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
  410. and model downsampling rate to compute the time-stamps of transcribed words.
  411. <Tip>
  412. Please take a look at the Example of [`~Wav2Vec2CTCTokenizer.decode`] to better understand how to make
  413. use of `output_word_offsets`. [`~Wav2Vec2CTCTokenizer.batch_decode`] works the same way with batched
  414. output.
  415. </Tip>
  416. kwargs (additional keyword arguments, *optional*):
  417. Will be passed to the underlying model specific decode method.
  418. Returns:
  419. `list[str]` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded
  420. sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when
  421. `output_char_offsets == True` or `output_word_offsets == True`.
  422. """
  423. batch_decoded = [
  424. self.decode(
  425. seq,
  426. skip_special_tokens=skip_special_tokens,
  427. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  428. output_char_offsets=output_char_offsets,
  429. output_word_offsets=output_word_offsets,
  430. **kwargs,
  431. )
  432. for seq in sequences
  433. ]
  434. if output_char_offsets or output_word_offsets:
  435. # transform list of dicts to dict of lists
  436. return Wav2Vec2CTCTokenizerOutput({k: [d[k] for d in batch_decoded] for k in batch_decoded[0]})
  437. return batch_decoded
  438. # overwritten from `tokenization_utils_base.py` because we need docs for `output_char_offsets`
  439. # and `output_word_offsets` here
  440. def decode(
  441. self,
  442. token_ids: Union[int, list[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
  443. skip_special_tokens: bool = False,
  444. clean_up_tokenization_spaces: Optional[bool] = None,
  445. output_char_offsets: bool = False,
  446. output_word_offsets: bool = False,
  447. **kwargs,
  448. ) -> str:
  449. """
  450. Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
  451. tokens and clean up tokenization spaces.
  452. Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
  453. Args:
  454. token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`):
  455. List of tokenized input ids. Can be obtained using the `__call__` method.
  456. skip_special_tokens (`bool`, *optional*, defaults to `False`):
  457. Whether or not to remove special tokens in the decoding.
  458. clean_up_tokenization_spaces (`bool`, *optional*):
  459. Whether or not to clean up the tokenization spaces.
  460. output_char_offsets (`bool`, *optional*, defaults to `False`):
  461. Whether or not to output character offsets. Character offsets can be used in combination with the
  462. sampling rate and model downsampling rate to compute the time-stamps of transcribed characters.
  463. <Tip>
  464. Please take a look at the example below to better understand how to make use of `output_char_offsets`.
  465. </Tip>
  466. output_word_offsets (`bool`, *optional*, defaults to `False`):
  467. Whether or not to output word offsets. Word offsets can be used in combination with the sampling rate
  468. and model downsampling rate to compute the time-stamps of transcribed words.
  469. <Tip>
  470. Please take a look at the example below to better understand how to make use of `output_word_offsets`.
  471. </Tip>
  472. kwargs (additional keyword arguments, *optional*):
  473. Will be passed to the underlying model specific decode method.
  474. Returns:
  475. `str` or [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`]: The list of decoded
  476. sentences. Will be a [`~models.wav2vec2.tokenization_wav2vec2.Wav2Vec2CTCTokenizerOutput`] when
  477. `output_char_offsets == True` or `output_word_offsets == True`.
  478. Example:
  479. ```python
  480. >>> # Let's see how to retrieve time steps for a model
  481. >>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
  482. >>> from datasets import load_dataset
  483. >>> import datasets
  484. >>> import torch
  485. >>> # import model, feature extractor, tokenizer
  486. >>> model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")
  487. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
  488. >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
  489. >>> # load first sample of English common_voice
  490. >>> dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="train", streaming=True)
  491. >>> dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=16_000))
  492. >>> dataset_iter = iter(dataset)
  493. >>> sample = next(dataset_iter)
  494. >>> # forward sample through model to get greedily predicted transcription ids
  495. >>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
  496. >>> logits = model(input_values).logits[0]
  497. >>> pred_ids = torch.argmax(logits, axis=-1)
  498. >>> # retrieve word stamps (analogous commands for `output_char_offsets`)
  499. >>> outputs = tokenizer.decode(pred_ids, output_word_offsets=True)
  500. >>> # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate
  501. >>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate
  502. >>> word_offsets = [
  503. ... {
  504. ... "word": d["word"],
  505. ... "start_time": round(d["start_offset"] * time_offset, 2),
  506. ... "end_time": round(d["end_offset"] * time_offset, 2),
  507. ... }
  508. ... for d in outputs.word_offsets
  509. ... ]
  510. >>> # compare word offsets with audio `en_train_0/common_voice_en_19121553.mp3` online on the dataset viewer:
  511. >>> # https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0/viewer/en
  512. >>> word_offsets[:3]
  513. [{'word': 'THE', 'start_time': 0.7, 'end_time': 0.78}, {'word': 'TRICK', 'start_time': 0.88, 'end_time': 1.08}, {'word': 'APPEARS', 'start_time': 1.2, 'end_time': 1.64}]
  514. ```"""
  515. # Convert inputs to python lists
  516. token_ids = to_py_obj(token_ids)
  517. return self._decode(
  518. token_ids=token_ids,
  519. skip_special_tokens=skip_special_tokens,
  520. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  521. output_char_offsets=output_char_offsets,
  522. output_word_offsets=output_word_offsets,
  523. **kwargs,
  524. )
  525. def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
  526. if not os.path.isdir(save_directory):
  527. logger.error(f"Vocabulary path ({save_directory}) should be a directory")
  528. return
  529. vocab_file = os.path.join(
  530. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
  531. )
  532. with open(vocab_file, "w", encoding="utf-8") as f:
  533. f.write(json.dumps(self.vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
  534. return (vocab_file,)
  535. class Wav2Vec2Tokenizer(PreTrainedTokenizer):
  536. """
  537. Constructs a Wav2Vec2 tokenizer.
  538. This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
  539. the superclass for more information regarding such methods.
  540. Args:
  541. vocab_file (`str`):
  542. File containing the vocabulary.
  543. bos_token (`str`, *optional*, defaults to `"<s>"`):
  544. The beginning of sentence token.
  545. eos_token (`str`, *optional*, defaults to `"</s>"`):
  546. The end of sentence token.
  547. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  548. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  549. token instead.
  550. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  551. The token used for padding, for example when batching sequences of different lengths.
  552. word_delimiter_token (`str`, *optional*, defaults to `"|"`):
  553. The token used for defining the end of a word.
  554. do_lower_case (`bool`, *optional*, defaults to `False`):
  555. Whether or not to lowercase the output when decoding.
  556. do_normalize (`bool`, *optional*, defaults to `False`):
  557. Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
  558. improve the performance for some models, *e.g.*,
  559. [wav2vec2-lv60](https://huggingface.co/models?search=lv60).
  560. return_attention_mask (`bool`, *optional*, defaults to `False`):
  561. Whether or not [`~Wav2Vec2Tokenizer.__call__`] should return `attention_mask`.
  562. <Tip>
  563. Wav2Vec2 models that have set `config.feat_extract_norm == "group"`, such as
  564. [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), have **not** been trained using
  565. `attention_mask`. For such models, `input_values` should simply be padded with 0 and no `attention_mask`
  566. should be passed.
  567. For Wav2Vec2 models that have set `config.feat_extract_norm == "layer"`, such as
  568. [wav2vec2-lv60](https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self), `attention_mask` should be
  569. passed for batched inference.
  570. </Tip>
  571. **kwargs
  572. Additional keyword arguments passed along to [`PreTrainedTokenizer`]
  573. """
  574. vocab_files_names = VOCAB_FILES_NAMES
  575. pretrained_vocab_files_map = {
  576. "vocab_file": {
  577. "facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/vocab.json"
  578. },
  579. "tokenizer_config_file": {
  580. "facebook/wav2vec2-base-960h": (
  581. "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json"
  582. ),
  583. },
  584. }
  585. model_input_names = ["input_values", "attention_mask"]
  def __init__(
      self,
      vocab_file,
      bos_token="<s>",
      eos_token="</s>",
      unk_token="<unk>",
      pad_token="<pad>",
      word_delimiter_token="|",
      do_lower_case=False,
      do_normalize=False,
      return_attention_mask=False,
      **kwargs,
  ):
      """
      Build the (deprecated) tokenizer from a JSON vocabulary file.

      Emits a `FutureWarning` pointing to the replacement classes, loads the token -> id table from `vocab_file`
      and derives the inverse id -> token table, then forwards the special tokens and options to the base class.
      """
      warnings.warn(
          "The class `Wav2Vec2Tokenizer` is deprecated and will be removed in version 5 of Transformers. Please use"
          " `Wav2Vec2Processor` or `Wav2Vec2CTCTokenizer` instead.",
          FutureWarning,
      )

      # Tokenizer-specific settings are stored before the base-class constructor runs.
      self._word_delimiter_token = word_delimiter_token
      self.do_lower_case = do_lower_case
      self.return_attention_mask = return_attention_mask
      self.do_normalize = do_normalize

      with open(vocab_file, encoding="utf-8") as handle:
          token_to_id = json.load(handle)
      self.encoder = token_to_id
      # inverse table; if two tokens share an id, the later one wins
      self.decoder = {token_id: token for token, token_id in token_to_id.items()}

      super().__init__(
          unk_token=unk_token,
          bos_token=bos_token,
          eos_token=eos_token,
          pad_token=pad_token,
          do_lower_case=do_lower_case,
          do_normalize=do_normalize,
          return_attention_mask=return_attention_mask,
          word_delimiter_token=word_delimiter_token,
          **kwargs,
      )
  @property
  def word_delimiter_token(self) -> Optional[str]:
      """
      `Optional[str]`: Word delimiter token (mapped to a space when tokens are joined into text). Returns `None` if
      it has not been set, logging an error in that case when `self.verbose` is true.
      """
      if self._word_delimiter_token is None:
          # Only the logging depends on verbosity; an unset token must always come back as `None`,
          # never as the string "None" (which `str(None)` would produce).
          if self.verbose:
              logger.error("Using word_delimiter_token, but it is not set yet.")
          return None
      return str(self._word_delimiter_token)
  @property
  def word_delimiter_token_id(self) -> Optional[int]:
      """
      `Optional[int]`: Vocabulary id of the word delimiter token, or `None` when no delimiter token is set.
      """
      if self._word_delimiter_token is not None:
          return self.convert_tokens_to_ids(self.word_delimiter_token)
      return None
  @word_delimiter_token.setter
  def word_delimiter_token(self, value):
      """Store `value` as the word delimiter token exactly as given (no conversion or validation)."""
      self._word_delimiter_token = value
  @word_delimiter_token_id.setter
  def word_delimiter_token_id(self, value):
      """
      Set the word delimiter token from its vocabulary id.

      Only the token string is stored, so the id is converted back to its token; assigning `None` unsets the
      delimiter.
      """
      # id -> token (not token -> id): `word_delimiter_token_id` re-derives the id from the stored token.
      self._word_delimiter_token = self.convert_ids_to_tokens(value) if value is not None else None
  @add_end_docstrings(WAV2VEC2_KWARGS_DOCSTRING)
  def __call__(
      self,
      raw_speech: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
      padding: Union[bool, str, PaddingStrategy] = False,
      max_length: Optional[int] = None,
      pad_to_multiple_of: Optional[int] = None,
      padding_side: Optional[str] = None,
      return_tensors: Optional[Union[str, TensorType]] = None,
      verbose: bool = True,
      **kwargs,
  ) -> BatchEncoding:
      """
      Prepare one raw waveform, or a batch of raw waveforms, as padded `input_values` for the model.

      A single waveform is wrapped into a batch of size one. When `self.do_normalize` is set, every waveform is
      normalized to zero mean and unit variance before padding. Extra `kwargs` are accepted but not used.

      Args:
          raw_speech (`np.ndarray`, `list[float]`, `list[np.ndarray]`, `list[list[float]]`):
              The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
              values, a list of numpy array or a list of list of float values. Must be mono channel audio, not
              stereo, i.e. single float per timestep.
          padding_side (`str`, *optional*):
              The side on which the model should have padding applied. Should be selected between ['right', 'left'].
              Default value is picked from the class attribute of the same name.
      """
      if isinstance(raw_speech, np.ndarray):
          # A 2-D array is a batch of waveforms; more than two dimensions means multi-channel audio.
          num_dims = len(raw_speech.shape)
          if num_dims > 2:
              raise ValueError(f"Only mono-channel audio is supported for input to {self}")
          is_batched = num_dims > 1
      else:
          # A list/tuple is a batch when its first element is itself a sequence.
          is_batched = isinstance(raw_speech, (list, tuple)) and isinstance(
              raw_speech[0], (np.ndarray, tuple, list)
          )

      if is_batched:
          # Containers whose first item is already an array are used as-is; otherwise convert each item.
          if isinstance(raw_speech[0], np.ndarray):
              batch = raw_speech
          else:
              batch = [np.asarray(waveform) for waveform in raw_speech]
      else:
          # `pad` always receives a batch, even for a single waveform.
          single = raw_speech if isinstance(raw_speech, np.ndarray) else np.asarray(raw_speech)
          batch = [single]

      if self.do_normalize:
          # Per-waveform standardization; the 1e-5 term keeps the divisor non-zero for constant input.
          batch = [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in batch]

      return self.pad(
          BatchEncoding({"input_values": batch}),
          padding=padding,
          max_length=max_length,
          pad_to_multiple_of=pad_to_multiple_of,
          padding_side=padding_side,
          return_attention_mask=self.return_attention_mask,
          return_tensors=return_tensors,
          verbose=verbose,
      )
  @property
  def vocab_size(self) -> int:
      """`int`: Number of entries in the id -> token table (added tokens not included)."""
      id_to_token = self.decoder
      return len(id_to_token)
  def get_vocab(self) -> dict:
      """Return a new dict of the base vocabulary merged with the added tokens (added tokens win on conflicts)."""
      vocab = dict(self.encoder)
      vocab.update(**self.added_tokens_encoder)
      return vocab
  def _convert_token_to_id(self, token: str) -> int:
      """Look up `token` in the vocabulary, falling back to the id of `unk_token` (or `None` if that is missing too)."""
      unk_id = self.encoder.get(self.unk_token)
      return self.encoder.get(token, unk_id)
  def _convert_id_to_token(self, index: int) -> str:
      """Look up the token for `index` in the id -> token table, falling back to `unk_token`."""
      return self.decoder.get(index, self.unk_token)
  def convert_tokens_to_string(self, tokens: list[str]) -> str:
      """
      Turn connectionist-temporal-classification (CTC) output tokens into a single string.

      Consecutive duplicates are collapsed, the padding token (the CTC blank) is removed, word delimiter tokens
      become spaces, the result is stripped, and it is lower-cased when `self.do_lower_case` is set.
      """
      # CTC collapse: keep one representative of every run of identical tokens.
      collapsed = (key for key, _run in groupby(tokens))
      # Drop the blank token only after collapsing, so it still separates genuine repeats.
      kept = [tok for tok in collapsed if tok != self.pad_token]
      text = "".join(" " if tok == self.word_delimiter_token else tok for tok in kept).strip()
      return text.lower() if self.do_lower_case else text
  def _decode(
      self,
      token_ids: list[int],
      skip_special_tokens: bool = False,
      clean_up_tokenization_spaces: Optional[bool] = None,
      **kwargs,
  ) -> str:
      """
      Decode a sequence of ids into text using CTC-style post-processing.

      A dedicated `_decode` is needed for `Wav2Vec2Tokenizer` because added tokens should be treated exactly the
      same as tokens of the base vocabulary, so `convert_tokens_to_string` is called on the whole token list and
      not individually on added tokens.

      Args:
          token_ids (`list[int]`):
              Ids to decode.
          skip_special_tokens (`bool`, *optional*, defaults to `False`):
              Whether to drop special tokens. The padding token is kept at this stage even when skipping, because
              it is the CTC blank that separates genuinely repeated characters (`L <pad> L` -> `LL`);
              `convert_tokens_to_string` removes it after grouping.
          clean_up_tokenization_spaces (`bool`, *optional*):
              Whether to apply `clean_up_tokenization` to the text. Defaults to `self.clean_up_tokenization_spaces`
              when `None`.
          kwargs:
              Accepted for signature compatibility; not used.

      Returns:
          `str`: The decoded text.
      """
      # Convert every id and filter afterwards: asking `convert_ids_to_tokens` to skip special tokens would
      # also remove the pad/blank token before CTC grouping and wrongly merge repeated characters.
      tokens = self.convert_ids_to_tokens(token_ids)

      if skip_special_tokens:
          pad_token = self.pad_token
          special_tokens = set(self.all_special_tokens)  # computed once, not per token
          tokens = [token for token in tokens if token == pad_token or token not in special_tokens]

      text = self.convert_tokens_to_string(tokens)

      if clean_up_tokenization_spaces is None:
          clean_up_tokenization_spaces = self.clean_up_tokenization_spaces
      return self.clean_up_tokenization(text) if clean_up_tokenization_spaces else text
  def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Optional[tuple[str]]:
      """
      Write the token -> id vocabulary (`self.encoder`) as JSON into `save_directory`.

      Args:
          save_directory (`str`):
              Existing directory to write the vocabulary file into.
          filename_prefix (`str`, *optional*):
              When given, prepended to the vocabulary file name, separated by `-`.

      Returns:
          `Optional[tuple[str]]`: A one-element tuple with the written file path, or `None` (after logging an
          error) when `save_directory` is not an existing directory.
      """
      if not os.path.isdir(save_directory):
          logger.error(f"Vocabulary path ({save_directory}) should be a directory")
          return None

      vocab_file = os.path.join(
          save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
      )

      # Sorted keys, non-ASCII kept verbatim, trailing newline.
      with open(vocab_file, "w", encoding="utf-8") as f:
          f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

      return (vocab_file,)
  # Public names exported by this module.
  __all__ = [
      "Wav2Vec2CTCTokenizer",
      "Wav2Vec2Tokenizer",
  ]