tokenization_whisper.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Tokenization classes for Whisper."""
  16. import json
  17. import os
  18. import warnings
  19. from functools import lru_cache
  20. from typing import Optional, Union
  21. import numpy as np
  22. import regex as re
  23. from ...tokenization_utils import AddedToken, PreTrainedTokenizer
  24. from ...utils import logging
  25. from .english_normalizer import BasicTextNormalizer, EnglishTextNormalizer
# File names of the resources that make up a saved tokenizer. The `vocab_file`, `merges_file` and
# `normalizer_file` keys match the parameter names of `WhisperTokenizer.__init__`; the dict is exposed on the class as
# `vocab_files_names`.
VOCAB_FILES_NAMES = {
    "vocab_file": "vocab.json",
    "tokenizer_file": "tokenizer.json",
    "merges_file": "merges.txt",
    "normalizer_file": "normalizer.json",
}
# NOTE(review): not referenced in the visible part of this file; presumably the maximum input length (in tokens) per
# checkpoint name -- confirm against the callers.
MAX_MODEL_INPUT_SIZES = {
    "openai/whisper-base": 448,
}
  35. # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
  36. def bytes_to_unicode():
  37. """
  38. Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
  39. characters the bpe code barfs on.
  40. The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
  41. if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
  42. decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
  43. tables between utf-8 bytes and unicode strings.
  44. """
  45. bs = (
  46. list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
  47. )
  48. cs = bs[:]
  49. n = 0
  50. for b in range(2**8):
  51. if b not in bs:
  52. bs.append(b)
  53. cs.append(2**8 + n)
  54. n += 1
  55. cs = [chr(n) for n in cs]
  56. return dict(zip(bs, cs))
# Module-level logger, obtained from the library's own `logging` utilities under this module's name.
logger = logging.get_logger(__name__)
  58. # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs
  59. def get_pairs(word):
  60. """
  61. Return set of symbol pairs in a word.
  62. Word is represented as tuple of symbols (symbols being variable-length strings).
  63. """
  64. pairs = set()
  65. prev_char = word[0]
  66. for char in word[1:]:
  67. pairs.add((prev_char, char))
  68. prev_char = char
  69. return pairs
# Language code -> lowercase English language name.
# The insertion order is significant: `WhisperTokenizer.prefix_tokens` derives a language token id from the position
# of the code in this dict (`<|startoftranscript|>` id + 1 + index), so entries must not be reordered.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}
# language code lookup by name, with a few language aliases
# (the inverse of `LANGUAGES`, plus alternative names that resolve to an existing code)
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}
# Task names accepted by `WhisperTokenizer.prefix_tokens`; any other non-None task raises a ValueError there.
TASK_IDS = ["translate", "transcribe"]
  189. class WhisperTokenizer(PreTrainedTokenizer):
  190. """
  191. Construct a Whisper tokenizer.
  192. This tokenizer inherits from [`PreTrainedTokenizer`] which contains some of the main methods. Users should refer to
  193. the superclass for more information regarding such methods.
  194. Args:
  195. vocab_file (`str`):
  196. Path to the vocabulary file.
  197. merges_file (`str`):
  198. Path to the merges file.
  199. normalizer_file (`str`, *optional*):
  200. Path to the normalizer_file file.
  201. errors (`str`, *optional*, defaults to `"replace"`):
  202. Paradigm to follow when decoding bytes to UTF-8. See
  203. [bytes.decode](https://docs.python.org/3/library/stdtypes.html#bytes.decode) for more information.
  204. unk_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
  205. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  206. token instead.
  207. bos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
  208. The beginning of sequence token. The `decoder_start_token_id` is used to set the first token as
  209. `"<|startoftranscript|>"` when generating.
  210. eos_token (`str`, *optional*, defaults to `"<|endoftext|>"`):
  211. The end of sequence token.
  212. pad_token (`str`, *optional*):
  213. The token used for padding, for example when batching sequences of different lengths.
  214. add_prefix_space (`bool`, *optional*, defaults to `False`):
  215. Whether or not to add an initial space to the input. This allows to treat the leading word just as any
  216. other word.
  217. language (`str`, *optional*):
  218. The language of the transcription text. The corresponding language id token is appended to the start of the
  219. sequence for multilingual speech recognition and speech translation tasks, e.g. for Spanish the token
  220. `"<|es|>"` is appended to the start of sequence. This should be used for multilingual fine-tuning only.
  221. task (`str`, *optional*):
  222. Task identifier to append at the start of sequence (if any). This should be used for multilingual
  223. fine-tuning, with `"transcribe"` for speech recognition and `"translate"` for speech translation.
  224. predict_timestamps (`bool`, *optional*, defaults to `False`):
  225. Whether to omit the `<|notimestamps|>` token at the start of the sequence.
  226. """
  227. vocab_files_names = VOCAB_FILES_NAMES
  228. model_input_names = ["input_ids", "attention_mask"]
  229. def __init__(
  230. self,
  231. vocab_file,
  232. merges_file,
  233. normalizer_file=None,
  234. errors="replace",
  235. unk_token="<|endoftext|>",
  236. bos_token="<|endoftext|>",
  237. eos_token="<|endoftext|>",
  238. pad_token=None,
  239. add_prefix_space=False,
  240. language=None,
  241. task=None,
  242. predict_timestamps=False,
  243. **kwargs,
  244. ):
  245. bos_token = (
  246. AddedToken(bos_token, lstrip=False, rstrip=False, normalized=False, special=True)
  247. if isinstance(bos_token, str)
  248. else bos_token
  249. )
  250. eos_token = (
  251. AddedToken(eos_token, lstrip=False, rstrip=False, normalized=False, special=True)
  252. if isinstance(eos_token, str)
  253. else eos_token
  254. )
  255. unk_token = (
  256. AddedToken(unk_token, lstrip=False, rstrip=False, normalized=False, special=True)
  257. if isinstance(unk_token, str)
  258. else unk_token
  259. )
  260. pad_token = (
  261. AddedToken(pad_token, lstrip=False, rstrip=False, normalized=False, special=True)
  262. if isinstance(pad_token, str)
  263. else pad_token
  264. )
  265. with open(vocab_file, encoding="utf-8") as vocab_handle:
  266. self.encoder = json.load(vocab_handle)
  267. self.decoder = {v: k for k, v in self.encoder.items()}
  268. self.errors = errors # how to handle errors in decoding
  269. self.byte_encoder = bytes_to_unicode()
  270. self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
  271. with open(merges_file, encoding="utf-8") as merges_handle:
  272. bpe_merges = merges_handle.read().split("\n")[1:-1]
  273. bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
  274. self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
  275. self.cache = {}
  276. self.add_prefix_space = add_prefix_space
  277. if normalizer_file is not None:
  278. with open(normalizer_file, encoding="utf-8") as vocab_handle:
  279. self.english_spelling_normalizer = json.load(vocab_handle)
  280. else:
  281. self.english_spelling_normalizer = None
  282. # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
  283. self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
  284. self.timestamp_pat = re.compile(r"<\|(\d+\.\d+)\|>")
  285. self.language = language
  286. super().__init__(
  287. errors=errors,
  288. unk_token=unk_token,
  289. bos_token=bos_token,
  290. eos_token=eos_token,
  291. pad_token=pad_token,
  292. add_prefix_space=add_prefix_space,
  293. **kwargs,
  294. )
  295. self.task = task
  296. self.predict_timestamps = predict_timestamps
  297. @property
  298. def vocab_size(self) -> int:
  299. return len(self.encoder)
  300. def get_vocab(self):
  301. vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
  302. vocab.update(self.added_tokens_encoder)
  303. return vocab
  304. # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe with GPT2 -> Whisper
  305. def bpe(self, token):
  306. if token in self.cache:
  307. return self.cache[token]
  308. word = tuple(token)
  309. pairs = get_pairs(word)
  310. if not pairs:
  311. return token
  312. while True:
  313. bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
  314. if bigram not in self.bpe_ranks:
  315. break
  316. first, second = bigram
  317. new_word = []
  318. i = 0
  319. while i < len(word):
  320. try:
  321. j = word.index(first, i)
  322. except ValueError:
  323. new_word.extend(word[i:])
  324. break
  325. else:
  326. new_word.extend(word[i:j])
  327. i = j
  328. if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
  329. new_word.append(first + second)
  330. i += 2
  331. else:
  332. new_word.append(word[i])
  333. i += 1
  334. new_word = tuple(new_word)
  335. word = new_word
  336. if len(word) == 1:
  337. break
  338. else:
  339. pairs = get_pairs(word)
  340. word = " ".join(word)
  341. self.cache[token] = word
  342. return word
  343. def set_prefix_tokens(
  344. self, language: Optional[str] = None, task: Optional[str] = None, predict_timestamps: Optional[bool] = None
  345. ):
  346. """
  347. Override the prefix tokens appended to the start of the label sequence. This method can be used standalone to
  348. update the prefix tokens as required when fine-tuning. Example:
  349. ```python
  350. >>> # instantiate the tokenizer and set the prefix token to Spanish
  351. >>> tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny", language="spanish")
  352. >>> # now switch the prefix token from Spanish to French
  353. >>> tokenizer.set_prefix_tokens(language="french")
  354. ```
  355. Args:
  356. language (`str`, *optional*, defaults to `None`):
  357. The language of the transcription text.
  358. task (`str`, *optional*, defaults to `None`):
  359. Task identifier to append at the start of sequence (if any).
  360. predict_timestamps (`bool`, *optional*, defaults to `None`):
  361. Whether to omit the `<|notimestamps|>` token at the start of the sequence.
  362. """
  363. self.language = language if language is not None else self.language
  364. self.task = task if task is not None else self.task
  365. self.predict_timestamps = predict_timestamps if predict_timestamps is not None else self.predict_timestamps
  366. @property
  367. def prefix_tokens(self) -> list[int]:
  368. bos_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
  369. translate_token_id = self.convert_tokens_to_ids("<|translate|>")
  370. transcribe_token_id = self.convert_tokens_to_ids("<|transcribe|>")
  371. notimestamps_token_id = self.convert_tokens_to_ids("<|notimestamps|>")
  372. langs = tuple(LANGUAGES.keys())
  373. if self.language is not None:
  374. self.language = self.language.lower()
  375. if self.language in TO_LANGUAGE_CODE:
  376. language_id = TO_LANGUAGE_CODE[self.language]
  377. elif self.language in TO_LANGUAGE_CODE.values():
  378. language_id = self.language
  379. else:
  380. is_language_code = len(self.language) == 2
  381. raise ValueError(
  382. f"Unsupported language: {self.language}. Language should be one of:"
  383. f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
  384. )
  385. if self.task is not None:
  386. if self.task not in TASK_IDS:
  387. raise ValueError(f"Unsupported task: {self.task}. Task should be in: {TASK_IDS}")
  388. bos_sequence = [bos_token_id]
  389. if self.language is not None:
  390. bos_sequence.append(bos_token_id + 1 + langs.index(language_id))
  391. if self.task is not None:
  392. bos_sequence.append(transcribe_token_id if self.task == "transcribe" else translate_token_id)
  393. if not self.predict_timestamps:
  394. bos_sequence.append(notimestamps_token_id)
  395. return bos_sequence
  396. # Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.build_inputs_with_special_tokens
  397. def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]:
  398. """Build model inputs from a sequence by appending eos_token_id."""
  399. if token_ids_1 is None:
  400. return self.prefix_tokens + token_ids_0 + [self.eos_token_id]
  401. # We don't expect to process pairs, but leave the pair logic for API consistency
  402. return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id]
  403. # Copied from transformers.models.speech_to_text.tokenization_speech_to_text.Speech2TextTokenizer.get_special_tokens_mask
  404. def get_special_tokens_mask(
  405. self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
  406. ) -> list[int]:
  407. """
  408. Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
  409. special tokens using the tokenizer `prepare_for_model` method.
  410. Args:
  411. token_ids_0 (`list[int]`):
  412. List of IDs.
  413. token_ids_1 (`list[int]`, *optional*):
  414. Optional second list of IDs for sequence pairs.
  415. already_has_special_tokens (`bool`, *optional*, defaults to `False`):
  416. Whether or not the token list is already formatted with special tokens for the model.
  417. Returns:
  418. `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
  419. """
  420. if already_has_special_tokens:
  421. return super().get_special_tokens_mask(
  422. token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
  423. )
  424. prefix_ones = [1] * len(self.prefix_tokens)
  425. suffix_ones = [1]
  426. if token_ids_1 is None:
  427. return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
  428. return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
  429. # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize with GPT2 -> Whisper
  430. def _tokenize(self, text):
  431. """Tokenize a string."""
  432. bpe_tokens = []
  433. for token in re.findall(self.pat, text):
  434. token = "".join(
  435. self.byte_encoder[b] for b in token.encode("utf-8")
  436. ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
  437. bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
  438. return bpe_tokens
  439. # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id with GPT2 -> Whisper
  440. def _convert_token_to_id(self, token):
  441. """Converts a token (str) in an id using the vocab."""
  442. return self.encoder.get(token, self.encoder.get(self.unk_token))
  443. def _convert_id_to_token(self, index):
  444. """
  445. Converts an index (integer) in a token (str) using the vocab. Whisper's base tokenizer always decodes OOV
  446. tokens as "", thus we do not use the `unk_token` here.
  447. """
  448. return self.decoder.get(index, "")
  449. def _normalize(self, text):
  450. warnings.warn(
  451. "The private method `_normalize` is deprecated and will be removed in v5 of Transformers."
  452. "You can normalize an input string using the Whisper English normalizer using the `normalize` method."
  453. )
  454. return self.normalize(text)
  455. def _basic_normalize(self, text, remove_diacritics=False):
  456. warnings.warn(
  457. "The private method `_basic_normalize` is deprecated and will be removed in v5 of Transformers."
  458. "You can normalize an input string using the Whisper basic normalizer using the `basic_normalize` method."
  459. )
  460. return self.basic_normalize(text, remove_diacritics=remove_diacritics)
  461. def normalize(self, text):
  462. """
  463. Normalize a given string using the `EnglishTextNormalizer` class, which performs commons transformation on
  464. english text.
  465. """
  466. normalizer = EnglishTextNormalizer(self.english_spelling_normalizer)
  467. return normalizer(text)
  468. @staticmethod
  469. def basic_normalize(text, remove_diacritics=False):
  470. """
  471. Normalize a given string using the `BasicTextNormalizer` class, which performs commons transformation on
  472. multilingual text.
  473. """
  474. normalizer = BasicTextNormalizer(remove_diacritics=remove_diacritics)
  475. return normalizer(text)
  476. def _decode_with_timestamps(
  477. self, token_ids, skip_special_tokens=False, time_precision=0.02, segment_size=1500
  478. ) -> str:
  479. """
  480. Timestamp tokens are above the special tokens' id range and are ignored by `decode()`. This method decodes
  481. given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
  482. """
  483. timestamp_begin = self.all_special_ids[-1] + 1
  484. outputs = [[]]
  485. cur_max_timestamp = 0.0
  486. prev_segments_len = 0.0
  487. penultimate_timestamp = 0.0
  488. for i, token in enumerate(token_ids):
  489. if token >= timestamp_begin:
  490. timestamp = float((token - timestamp_begin) * time_precision)
  491. if timestamp < cur_max_timestamp:
  492. # next segment has started
  493. last_was_single_ending = i >= 2 and not (
  494. token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
  495. )
  496. if last_was_single_ending:
  497. prev_segments_len += time_precision * segment_size
  498. else:
  499. cur_max_timestamp = penultimate_timestamp
  500. prev_segments_len += penultimate_timestamp
  501. outputs = outputs[:-2]
  502. penultimate_timestamp = cur_max_timestamp
  503. cur_max_timestamp = timestamp
  504. outputs.append(f"<|{(timestamp + prev_segments_len):.2f}|>")
  505. outputs.append([])
  506. else:
  507. outputs[-1].append(token)
  508. outputs = [
  509. s if isinstance(s, str) else self.decode(s, skip_special_tokens=skip_special_tokens) for s in outputs
  510. ]
  511. return "".join(outputs)
  512. def _compute_offsets(self, token_ids, time_precision=0.02, segment_size=1500):
  513. """
  514. Compute offsets for a given tokenized input
  515. Args:
  516. token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`):
  517. List of tokenized input ids. Can be obtained using the `__call__` method.
  518. time_precision (`float`, *optional*, defaults to 0.02):
  519. The time ratio to convert from token to time.
  520. segment_size (`int`, *optional*, defaults to 1500):
  521. The number of features in the input mel spectrogram.
  522. """
  523. offsets = []
  524. # ensure torch tensor of token ids is placed on cpu
  525. if "torch" in str(type(token_ids)) and (hasattr(token_ids, "cpu") and callable(token_ids.cpu)):
  526. token_ids = token_ids.cpu()
  527. token_ids = np.array(token_ids)
  528. if token_ids.shape[0] > 1 and len(token_ids.shape) > 1:
  529. raise ValueError("Can only process a single input at a time")
  530. timestamp_begin = self.all_special_ids[-1] + 1
  531. timestamp_tokens = token_ids >= timestamp_begin
  532. consecutive = np.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0] + 1
  533. if consecutive.shape[0] == 0 and timestamp_tokens.sum() <= 1:
  534. # either there are no timestamps or there are no consecutive ones
  535. return []
  536. elif np.where(timestamp_tokens)[0][-1] + 1 not in consecutive:
  537. # we add the final timestamp if it is not already in the list
  538. consecutive = np.append(consecutive, np.where(timestamp_tokens)[0][-1] + 1)
  539. last_slice = np.where(timestamp_tokens)[0][0]
  540. cur_max_timestamp = 0
  541. prev_segments_len = 0
  542. for current_slice in consecutive:
  543. sliced_tokens = token_ids[last_slice:current_slice]
  544. if len(sliced_tokens) > 1:
  545. start_timestamp_position = sliced_tokens[0].item() - timestamp_begin
  546. end_timestamp_position = sliced_tokens[-1].item() - timestamp_begin
  547. if start_timestamp_position < cur_max_timestamp:
  548. # next segment has started
  549. is_single_ending = last_slice >= 2 and not (
  550. token_ids[last_slice - 2] >= timestamp_begin and token_ids[last_slice - 1] >= timestamp_begin
  551. )
  552. if is_single_ending:
  553. prev_segments_len += segment_size
  554. else:
  555. prev_segments_len += cur_max_timestamp
  556. cur_max_timestamp = end_timestamp_position
  557. # strip timestamp tokens from the text output
  558. sliced_tokens = self._preprocess_token_ids(sliced_tokens)
  559. text = self._decode(sliced_tokens)
  560. text = self._filter_timestamp_ids(text)
  561. offsets.append(
  562. {
  563. "text": text,
  564. "timestamp": (
  565. start_timestamp_position * time_precision + prev_segments_len * time_precision,
  566. end_timestamp_position * time_precision + prev_segments_len * time_precision,
  567. ),
  568. }
  569. )
  570. last_slice = current_slice
  571. return offsets
  572. @lru_cache
  573. def timestamp_ids(self, time_precision=0.02):
  574. """
  575. Compute the timestamp token ids for a given precision and save to least-recently used (LRU) cache.
  576. Args:
  577. time_precision (`float`, *optional*, defaults to 0.02):
  578. The time ratio to convert from token to time.
  579. """
  580. return self.convert_tokens_to_ids([("<|%.2f|>" % (i * time_precision)) for i in range(1500 + 1)])
  581. def _preprocess_token_ids(self, token_ids, skip_special_tokens: bool = False):
  582. """
  583. Pre-process the token ids for decoding by removing the prompt tokens ids and timestamp token ids.
  584. Args:
  585. token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`):
  586. List of tokenized input ids. Typically, obtained using the `__call__` method of the tokenizer.
  587. skip_special_tokens (`bool`, *optional*, defaults to `False`):
  588. Whether or not to remove special tokens from the token ids. If `True`, the prompt token ids will be
  589. removed.
  590. """
  591. if skip_special_tokens:
  592. prompt_token_id = self.convert_tokens_to_ids("<|startofprev|>")
  593. decoder_start_token_id = self.convert_tokens_to_ids("<|startoftranscript|>")
  594. token_ids = self._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
  595. return token_ids
  596. def _filter_timestamp_ids(self, token_ids):
  597. return re.sub(self.timestamp_pat, "", token_ids)
  598. def decode(
  599. self,
  600. token_ids,
  601. skip_special_tokens: bool = False,
  602. clean_up_tokenization_spaces: Optional[bool] = None,
  603. output_offsets: bool = False,
  604. time_precision: float = 0.02,
  605. decode_with_timestamps: bool = False,
  606. normalize: bool = False,
  607. basic_normalize: bool = False,
  608. remove_diacritics: bool = False,
  609. **kwargs,
  610. ) -> str:
  611. """
  612. Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
  613. tokens and clean up tokenization spaces.
  614. Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
  615. Args:
  616. token_ids (`Union[int, list[int], np.ndarray, torch.Tensor, tf.Tensor]`):
  617. List of tokenized input ids. Can be obtained using the `__call__` method.
  618. skip_special_tokens (`bool`, *optional*, defaults to `False`):
  619. Whether or not to remove special tokens in the decoding. Will remove the previous tokens (pre-prompt)
  620. if present.
  621. clean_up_tokenization_spaces (`bool`, *optional*):
  622. Whether or not to clean up the tokenization spaces. If `None`, will default to
  623. `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
  624. output_offsets (`bool`, *optional*, defaults to `False`):
  625. Whether or not to output the offsets of the tokens. This should only be set if the model predicted
  626. timestamps. If there are previous tokens (pre-prompt) to decode, they will only appear in the decoded
  627. text if they contain timestamp tokens.
  628. time_precision (`float`, *optional*, defaults to 0.02):
  629. The time ratio to convert from token to time.
  630. decode_with_timestamps (`bool`, *optional*, defaults to `False`):
  631. Whether or not to decode with timestamps included in the raw text.
  632. normalize (`bool`, *optional*, defaults to `False`):
  633. Whether or not to apply the English text normalizer to the decoded text. Only applicable when the
  634. target text is in English. Otherwise, the basic text normalizer should be applied.
  635. basic_normalize (`bool`, *optional*, defaults to `False`):
  636. Whether or not to apply the Basic text normalizer to the decoded text. Applicable to multilingual
  637. target text.
  638. remove_diacritics (`bool`, *optional*, defaults to `False`):
  639. Whether or not to remove diacritics when applying the Basic text normalizer. Removing diacritics may
  640. destroy information in the decoded text, hence it should be used with caution.
  641. kwargs (additional keyword arguments, *optional*):
  642. Will be passed to the underlying model specific decode method.
  643. Returns:
  644. `str`: The decoded sentence.
  645. """
  646. filtered_ids = self._preprocess_token_ids(
  647. token_ids,
  648. skip_special_tokens=skip_special_tokens,
  649. )
  650. text = super().decode(
  651. filtered_ids,
  652. skip_special_tokens=skip_special_tokens,
  653. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  654. normalize=normalize,
  655. basic_normalize=basic_normalize,
  656. remove_diacritics=remove_diacritics,
  657. **kwargs,
  658. )
  659. if decode_with_timestamps:
  660. # legacy method to decode timestamps when not included in the tokenizer vocabulary
  661. text = self._decode_with_timestamps(
  662. filtered_ids, time_precision=time_precision, skip_special_tokens=skip_special_tokens
  663. )
  664. else:
  665. text = self._filter_timestamp_ids(text)
  666. # retrieve offsets
  667. if output_offsets:
  668. offsets = self._compute_offsets(token_ids, time_precision=time_precision)
  669. return {"text": text, "offsets": offsets}
  670. return text
  671. def _decode(
  672. self,
  673. token_ids: Union[int, list[int]],
  674. skip_special_tokens: bool = False,
  675. normalize: bool = False,
  676. basic_normalize: bool = False,
  677. remove_diacritics: bool = False,
  678. **kwargs,
  679. ) -> str:
  680. self._decode_use_source_tokenizer = kwargs.pop("use_source_tokenizer", False)
  681. filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
  682. # To avoid mixing byte-level and unicode for byte-level BPT
  683. # we need to build string separately for added tokens and byte-level tokens
  684. # cf. https://github.com/huggingface/transformers/issues/1133
  685. sub_texts = []
  686. current_sub_text = []
  687. for token in filtered_tokens:
  688. if skip_special_tokens and token in self.all_special_ids:
  689. continue
  690. if token in self.added_tokens_encoder:
  691. if current_sub_text:
  692. sub_texts.append(self.convert_tokens_to_string(current_sub_text))
  693. current_sub_text = []
  694. sub_texts.append(token)
  695. else:
  696. current_sub_text.append(token)
  697. if current_sub_text:
  698. sub_texts.append(self.convert_tokens_to_string(current_sub_text))
  699. text = "".join(sub_texts)
  700. if normalize:
  701. clean_text = self.normalize(text)
  702. return clean_text
  703. elif basic_normalize:
  704. clean_text = self.basic_normalize(text, remove_diacritics=remove_diacritics)
  705. return clean_text
  706. else:
  707. return text
  708. # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.convert_tokens_to_string with GPT2 -> Whisper
  def convert_tokens_to_string(self, tokens):
      """Join byte-level tokens and decode the bytes they stand for into a single string."""
      joined = "".join(tokens)
      # Every character of a byte-level token maps back to exactly one raw byte.
      raw_bytes = bytearray(self.byte_decoder[symbol] for symbol in joined)
      return raw_bytes.decode("utf-8", errors=self.errors)
  def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
      """
      Write the vocabulary, the BPE merges and, when one is set, the English spelling normalizer into
      `save_directory`.

      Args:
          save_directory (`str`):
              Existing directory that receives the files. If it is not a directory, an error is logged and
              nothing is written.
          filename_prefix (`str`, *optional*):
              When given, `"<prefix>-"` is prepended to each file name.

      Returns:
          The vocabulary, merges and normalizer file paths, in that order. The normalizer path is returned even
          when no normalizer file was written. Returns `None` when `save_directory` is not a directory.
      """
      if not os.path.isdir(save_directory):
          logger.error(f"Vocabulary path ({save_directory}) should be a directory")
          return

      prefix = filename_prefix + "-" if filename_prefix else ""
      vocab_file, merge_file, normalizer_file = (
          os.path.join(save_directory, prefix + VOCAB_FILES_NAMES[key])
          for key in ("vocab_file", "merges_file", "normalizer_file")
      )

      with open(vocab_file, "w", encoding="utf-8") as vocab_handle:
          vocab_handle.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

      # Merges are written in rank order; a gap in the ranks is reported but does not stop the export.
      expected_rank = 0
      with open(merge_file, "w", encoding="utf-8") as merges_handle:
          merges_handle.write("#version: 0.2\n")
          for merge_pair, rank in sorted(self.bpe_ranks.items(), key=lambda item: item[1]):
              if expected_rank != rank:
                  logger.warning(
                      f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
                      " Please check that the tokenizer is not corrupted!"
                  )
                  expected_rank = rank
              merges_handle.write(" ".join(merge_pair) + "\n")
              expected_rank += 1

      normalizer_mapping = self.english_spelling_normalizer
      if normalizer_mapping is not None:
          with open(normalizer_file, "w", encoding="utf-8") as normalizer_handle:
              normalizer_handle.write(
                  json.dumps(normalizer_mapping, indent=2, sort_keys=True, ensure_ascii=False) + "\n"
              )

      return vocab_file, merge_file, normalizer_file
  747. # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.prepare_for_tokenization with GPT2 -> Whisper
  def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
      """
      Prepend a space to `text` when the input is pre-split into words or a prefix space is requested.

      `add_prefix_space` is consumed from `kwargs` (falling back to `self.add_prefix_space`); the remaining
      keyword arguments are returned untouched alongside the text.
      """
      wants_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
      needs_space = is_split_into_words or wants_prefix_space
      return (" " + text if needs_space else text, kwargs)
  def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True):
      """
      Build the `(position, token_id)` pairs to force at the start of generation.

      The prefix tokens are refreshed through `set_prefix_tokens` first. The first prefix token is left out
      because it is the token generation starts from, so the remaining ones are numbered from position 1.
      """
      self.set_prefix_tokens(task=task, language=language, predict_timestamps=not no_timestamps)
      # prefix tokens: <|startoftranscript|> <|lang_id|> <|task|> <|notimestamps|>
      # forced tokens:                       <|lang_id|> <|task|> <|notimestamps|>
      return list(enumerate(self.prefix_tokens[1:], start=1))
  def _decode_asr(self, model_outputs, *, return_timestamps, return_language, time_precision):
      """Forward to the module-level `_decode_asr` function, passing this tokenizer as its first argument."""
      options = {
          "return_timestamps": return_timestamps,
          "return_language": return_language,
          "time_precision": time_precision,
      }
      return _decode_asr(self, model_outputs, **options)
  def get_prompt_ids(self, text: str, return_tensors="np"):
      """
      Converts prompt text to IDs that can be passed to [`~WhisperForConditionalGeneration.generate`].

      Args:
          text (`str`):
              Prompt text. It is stripped and re-prefixed with a single space before encoding.
          return_tensors (defaults to `"np"`):
              Tensor type handed to `convert_to_tensors` on the encoding.

      Returns:
          The `input_ids` of the encoding, starting with the id(s) of `<|startofprev|>`.

      Raises:
          ValueError: If any id after the first one is greater than or equal to `self.all_special_ids[0]`.
      """
      encoded = self("<|startofprev|>", " " + text.strip(), add_special_tokens=False)

      # Check for special tokens: the first id is skipped, everything after it must be plain text.
      for candidate_id in encoded["input_ids"][1:]:
          if candidate_id >= self.all_special_ids[0]:
              token = self.convert_ids_to_tokens(candidate_id)
              raise ValueError(f"Encountered text in the prompt corresponding to disallowed special token: {token}.")

      encoded.convert_to_tensors(tensor_type=return_tensors)
      return encoded["input_ids"]
  781. def _strip_prompt(self, token_ids: list[int], prompt_token_id: int, decoder_start_token_id: int):
  782. if not isinstance(token_ids, list):
  783. token_ids = self._convert_to_list(token_ids)
  784. # handle case of empty token_ids for decoding with timestamps.
  785. # at this point token_ids is a list, so it is safe to use if not check.
  786. if not token_ids:
  787. return token_ids
  788. has_prompt = token_ids[0] == prompt_token_id
  789. if has_prompt:
  790. if decoder_start_token_id in token_ids:
  791. return token_ids[token_ids.index(decoder_start_token_id) :]
  792. else:
  793. return []
  794. return token_ids
  795. @staticmethod
  796. def _convert_to_list(token_ids):
  797. # convert type to ndarray if necessary
  798. if hasattr(token_ids, "numpy"):
  799. if "torch" in str(type(token_ids)):
  800. token_ids = token_ids.cpu().numpy()
  801. elif "tensorflow" in str(type(token_ids)):
  802. token_ids = token_ids.numpy()
  803. elif "jaxlib" in str(type(token_ids)):
  804. token_ids = token_ids.tolist()
  805. # now the token ids are either a numpy array, or a list of lists
  806. if isinstance(token_ids, np.ndarray):
  807. token_ids = token_ids.tolist()
  808. return token_ids
  def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, time_precision, segment_size=1500):
      """
      Internal method meant to only be used by asr pipeline. Handles all the little quirks specific to whisper to handle
      the various options not allowed in other seq2seq models

      Args:
          tokenizer:
              Tokenizer used to look up the `<|notimestamps|>`, `<|startofprev|>` and `<|startoftranscript|>` ids,
              to strip prompts (`_strip_prompt`) and to decode ids (`decode`).
          model_outputs:
              Iterable of per-chunk outputs. Each one is indexed with `"tokens"` (first row is used), may carry a
              `"stride"` entry unpacked as `(chunk_len, stride_left, stride_right)`, and must carry
              `"token_timestamps"` when `return_timestamps == "word"`.
          return_timestamps:
              Falsy for text only, truthy for chunk-level timestamps, or the string `"word"` for word-level
              timestamps.
          return_language:
              Whether the `"language"` field is kept in the returned chunks.
          time_precision:
              The time ratio to convert from a timestamp token offset to time.
          segment_size (defaults to 1500):
              Multiplied by `time_precision` to advance the running offset when concatenated segments are
              detected and the previous segment ended with a single timestamp token.

      Returns:
          A `(full_text, optional)` tuple. `full_text` is the concatenation of all chunk texts. `optional` is
          `{}` when neither timestamps nor language are requested, otherwise `{"chunks": [...]}` holding the
          chunk dicts (or the per-word dicts when `return_timestamps == "word"`).
      """
      # =========== Overview ============
      # - iterate over all outputs
      # - all tokens within output
      # - Each token can be
      #   - language token
      #   - special token
      #   - timestamp token
      #   - text token
      # - We accumulate the text tokens.
      # - We split on end timestamps
      # - Lots of complexity comes from stride and timestamps

      last_language = None

      def new_chunk():
          # Reads `last_language` at call time, so a fresh chunk inherits the most recently seen language.
          return {"language": last_language, "timestamp": [None, None], "text": ""}

      # Welcome to the state machine !
      chunks = []
      chunk = new_chunk()
      time_offset = 0.0
      # Ids at or above this value are handled as timestamp tokens below.
      timestamp_begin = tokenizer.convert_tokens_to_ids("<|notimestamps|>") + 1
      # Token lists (one per model output) not yet flushed into a chunk; merged by
      # `_find_longest_common_sequence` when a chunk is closed.
      previous_tokens = []
      previous_token_timestamps = []
      skip = False
      right_stride_start = None

      all_special_ids = set(tokenizer.all_special_ids)
      prompt_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
      decoder_start_token_id = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
      # - iterate over all outputs
      # NOTE: `chunk_id` is not used in the loop body.
      for chunk_id, output in enumerate(model_outputs):
          # We can drop everything to Python list, it's going to make
          # our lives easier
          token_ids = output["tokens"][0].tolist()
          # (possibly) remove the prompt from the token ids
          token_ids = tokenizer._strip_prompt(token_ids, prompt_token_id, decoder_start_token_id)
          if return_timestamps == "word":
              # NOTE(review): indexed below with positions of the prompt-stripped `token_ids` - confirm both
              # stay aligned when a prompt was stripped.
              token_timestamps = output["token_timestamps"][0].tolist()

          # Those keep track of timestamps within strides
          # Which need to be skipped and resolve all tokens in a single
          # chunk.
          last_timestamp = None
          first_timestamp = timestamp_begin

          # long form generation: we need to handle the case where the call to generate returns concatenated segments,
          # with underlying multiple calls to generate
          cur_max_timestamp = 0.0
          prev_segments_len = 0.0
          penultimate_timestamp = 0.0

          if "stride" in output:
              chunk_len, stride_left, stride_right = output["stride"]
              # Offset the timings to account for the other `model_outputs`.
              time_offset -= stride_left
              right_stride_start = chunk_len - stride_right

              # Keeping track of timestamps within strides
              # We're going to NOT split on those, and delay until we're
              # out of BOTH stride. Otherwise lots of issues occur and
              # corner cases
              if stride_left:
                  first_timestamp = stride_left / time_precision + timestamp_begin
              if stride_right:
                  # Walk backwards to find the earliest timestamp token that still belongs to the right stride.
                  for token in reversed(token_ids):
                      if token >= timestamp_begin:
                          # There can be several token in the right stride
                          # But the last one is ALWAYS going to be skipped
                          if (
                              last_timestamp is not None
                              and (token - timestamp_begin) * time_precision < right_stride_start
                          ):
                              break
                          last_timestamp = token

          current_tokens = []
          current_token_timestamps = []

          # - all tokens within output
          for i, token in enumerate(token_ids):
              # 4 possible states for each token
              # - 1/ Language code
              # - 2/ all other special tokens (which we ignore)
              # - 3/ Timestamp
              # - 4/ Regular text
              if token in all_special_ids:
                  # Either language code or other
                  text = tokenizer.decode([token])
                  # Removing outer shell <|XX|>
                  text = text[2:-2]
                  language = LANGUAGES.get(text)
                  if language is not None:
                      # 1/ Indeed some language
                      # TODO Handle when language is different from the previous
                      # one, and we cannot use timestamped tokens to create chunks
                      if last_language and language != last_language and not return_timestamps:
                          # Language switch without timestamps: close the current chunk here.
                          previous_tokens.append(current_tokens)
                          resolved_tokens = _find_longest_common_sequence(previous_tokens)
                          resolved_text = tokenizer.decode(resolved_tokens)
                          chunk["text"] = resolved_text
                          chunks.append(chunk)

                          # Flush all our temporary context
                          previous_tokens = []
                          current_tokens = []
                          chunk = new_chunk()
                      chunk["language"] = language
                      last_language = language
                  else:
                      # 2/ This is a regular special token, ignoring it
                      pass
              elif token >= timestamp_begin:
                  # 3/ Timestamp token

                  timestamp = float((token - timestamp_begin) * time_precision)

                  if timestamp < cur_max_timestamp:
                      # next segment has started
                      # True when the previous segment did not end with two consecutive timestamp tokens.
                      last_was_single_ending = i >= 2 and not (
                          token_ids[i - 1] >= timestamp_begin and token_ids[i - 2] >= timestamp_begin
                      )
                      if last_was_single_ending:
                          prev_segments_len += time_precision * segment_size
                      else:
                          cur_max_timestamp = penultimate_timestamp
                          prev_segments_len += penultimate_timestamp

                  penultimate_timestamp = cur_max_timestamp
                  cur_max_timestamp = timestamp

                  # Absolute time: position inside the segment, shifted by the stride bookkeeping and by the
                  # accumulated length of earlier concatenated segments.
                  time = (token - timestamp_begin) * time_precision + time_offset + prev_segments_len

                  time = round(time, 2)
                  if last_timestamp and token >= last_timestamp:
                      # Whisper outputted a timestamp token, but it falls within
                      # our stride, so we're going to skip it for the time being
                      # and resolve this later
                      # Skip is necessary because timestamp tokens always come
                      # by pair, so we need to skip the next one too (which would mark the start of another chunk).
                      skip = True
                  elif skip or (previous_tokens and token < first_timestamp):
                      skip = False
                  elif chunk["timestamp"][0] is None:
                      chunk["timestamp"][0] = time
                  else:
                      # This is the end of the timestamp chunk
                      if time == chunk["timestamp"][0]:
                          # This is a bug in timestamp token output
                          # where we're taking the duplicate token
                          # as a stop where it should be a start.
                          # This is an issue in the underlying model output
                          # Let's just skip it so it becomes de-facto
                          # a start again
                          pass
                      else:
                          chunk["timestamp"][1] = time
                          # Handling merges.
                          previous_tokens.append(current_tokens)
                          if return_timestamps == "word":
                              previous_token_timestamps.append(current_token_timestamps)
                          # `previous_token_timestamps` stays empty outside word mode, in which case the second
                          # returned value is an empty list.
                          resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
                              previous_tokens, previous_token_timestamps
                          )
                          resolved_text = tokenizer.decode(resolved_tokens)
                          chunk["text"] = resolved_text
                          if return_timestamps == "word":
                              chunk["words"] = _collate_word_timestamps(
                                  tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
                              )
                          chunks.append(chunk)

                          # Flush all our temporary context
                          previous_tokens = []
                          current_tokens = []
                          previous_token_timestamps = []
                          current_token_timestamps = []
                          chunk = new_chunk()
              else:
                  # 4/ Regular token
                  # We just append to the list of all tokens so we can handle
                  # merges later and decode into text.
                  current_tokens.append(token)
                  if return_timestamps == "word":
                      # A token starts where the previous token's timestamp ends (or at the offset for the first one).
                      if i == 0:
                          start_time = round(0.0 + time_offset, 2)
                      else:
                          start_time = round(token_timestamps[i - 1] + time_offset, 2)
                      end_time = round(token_timestamps[i] + time_offset, 2)
                      current_token_timestamps.append((start_time, end_time))

          if "stride" in output:
              time_offset += chunk_len - stride_right

          # Leftover tokens
          if current_tokens:
              previous_tokens.append(current_tokens)
              if return_timestamps == "word":
                  previous_token_timestamps.append(current_token_timestamps)
          elif not (any(p for p in previous_tokens)):
              # Nothing pending at all: start the next output from a clean chunk.
              chunk = new_chunk()
              previous_tokens = []
              current_tokens = []
              previous_token_timestamps = []
              current_token_timestamps = []

      if previous_tokens:
          if return_timestamps:
              logger.warning(
                  "Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. "
                  "Also make sure WhisperTimeStampLogitsProcessor was used during generation."
              )
          # Happens when we don't use timestamps
          resolved_tokens, resolved_token_timestamps = _find_longest_common_sequence(
              previous_tokens, previous_token_timestamps
          )
          resolved_text = tokenizer.decode(resolved_tokens)
          chunk["text"] = resolved_text
          if return_timestamps == "word":
              chunk["words"] = _collate_word_timestamps(
                  tokenizer, resolved_tokens, resolved_token_timestamps, last_language, return_language
              )
          chunks.append(chunk)

      # Preparing and cleaning up the pipeline output
      full_text = "".join(chunk["text"] for chunk in chunks)
      if return_timestamps or return_language:
          for chunk in chunks:
              if not return_timestamps:
                  chunk.pop("timestamp")
              else:
                  chunk["timestamp"] = tuple(chunk["timestamp"])
              if not return_language:
                  chunk.pop("language")

          if return_timestamps == "word":
              # Word mode returns the flattened per-word entries instead of the chunk dicts.
              new_chunks = []
              for chunk in chunks:
                  new_chunks.extend(chunk["words"])
              optional = {"chunks": new_chunks}
          else:
              optional = {"chunks": chunks}
      else:
          optional = {}
      return full_text, optional
  1037. def _find_longest_common_sequence(sequences, token_timestamp_sequences=None):
  1038. # It would be much harder to do O(n) because of fault tolerance.
  1039. # We actually have a really good property which is that the total sequence
  1040. # MUST be those subsequences in order.
  1041. # If token_timestamp_sequences is provided, will split those sequences in
  1042. # exactly the same way.
  1043. left_sequence = sequences[0]
  1044. left_length = len(left_sequence)
  1045. total_sequence = []
  1046. if token_timestamp_sequences:
  1047. left_token_timestamp_sequence = token_timestamp_sequences[0]
  1048. total_token_timestamp_sequence = []
  1049. for seq_idx, right_sequence in enumerate(sequences[1:]):
  1050. # index = 0
  1051. max_ = 0.0
  1052. max_indices = (left_length, left_length, 0, 0)
  1053. # Here we're sliding matches
  1054. # [a, b, c, d]
  1055. # [c, d, f]
  1056. # = [c] == [d]
  1057. #
  1058. # [a, b, c, d]
  1059. # [c, d, f]
  1060. # = [c, d] == [c, d]
  1061. #
  1062. #
  1063. # [a, b, c, d]
  1064. # [c, d, f]
  1065. #
  1066. # = [b, c, d] == [c, d, f]
  1067. #
  1068. # [a, b, c, d]
  1069. # [c, d, f]
  1070. #
  1071. # [a, b, c] == [c, d, f]
  1072. #
  1073. # [a, b, c, d]
  1074. # [d, f]
  1075. #
  1076. # [a, b] == [d, f]
  1077. #
  1078. # [a, b, c, d]
  1079. # [f]
  1080. #
  1081. # [a] == [f]
  1082. right_length = len(right_sequence)
  1083. for i in range(1, left_length + right_length):
  1084. # epsilon to favor long perfect matches
  1085. eps = i / 10000.0
  1086. # Slightly convoluted because we don't want out of bound indices
  1087. # This will be necessary for a small conflict resolution optimization
  1088. # later
  1089. left_start = max(0, left_length - i)
  1090. left_stop = min(left_length, left_length + right_length - i)
  1091. left = np.array(left_sequence[left_start:left_stop])
  1092. right_start = max(0, i - left_length)
  1093. right_stop = min(right_length, i)
  1094. right = np.array(right_sequence[right_start:right_stop])
  1095. # We can only match subsequences of the same size.
  1096. if len(left) != len(right):
  1097. raise RuntimeError(
  1098. "There is a bug within whisper `decode_asr` function, please report it. Dropping to prevent bad inference."
  1099. )
  1100. if token_timestamp_sequences:
  1101. # Get length of longest subsequence of tokens that match
  1102. # and have timestamps that are in order
  1103. matches = sum(
  1104. 1
  1105. for idx, elem in enumerate(left)
  1106. if (
  1107. elem == right[idx]
  1108. and left_token_timestamp_sequence[left_start + idx]
  1109. <= token_timestamp_sequences[seq_idx + 1][right_start + idx]
  1110. )
  1111. )
  1112. else:
  1113. matches = np.sum(left == right)
  1114. matching = matches / i + eps
  1115. if matches > 1 and matching > max_:
  1116. max_ = matching
  1117. max_indices = (left_start, left_stop, right_start, right_stop)
  1118. (left_start, left_stop, right_start, right_stop) = max_indices
  1119. # This is a small conflict optimization since those sequences overlap
  1120. # in audio.
  1121. # We're going to give more confidence to the left sequence
  1122. # for the left of the overlap,
  1123. # and to the right of the sequence, for the right of the overlap
  1124. left_mid = (left_stop + left_start) // 2
  1125. right_mid = (right_stop + right_start) // 2
  1126. total_sequence.extend(left_sequence[:left_mid])
  1127. left_sequence = right_sequence[right_mid:]
  1128. left_length = len(left_sequence)
  1129. if token_timestamp_sequences:
  1130. total_token_timestamp_sequence.extend(left_token_timestamp_sequence[:left_mid])
  1131. left_token_timestamp_sequence = token_timestamp_sequences[seq_idx + 1][right_mid:]
  1132. total_sequence.extend(left_sequence)
  1133. if token_timestamp_sequences is None:
  1134. return total_sequence
  1135. if len(token_timestamp_sequences) > 0:
  1136. total_token_timestamp_sequence.extend(left_token_timestamp_sequence)
  1137. return total_sequence, total_token_timestamp_sequence
  1138. else:
  1139. return total_sequence, []
  def _collate_word_timestamps(tokenizer, tokens, token_timestamps, language, return_language):
      """
      Group tokens into words and give each word a `(start, end)` timestamp.

      A word starts at the start time of its first token and ends at the end time of its last token. Each
      returned dict has `"text"` and `"timestamp"`, plus `"language"` when `return_language` is truthy.
      """
      words, _, token_indices = _combine_tokens_into_words(tokenizer, tokens, language)
      extra_fields = {"language": language} if return_language else {}

      timings = []
      for word, indices in zip(words, token_indices):
          first_idx, last_idx = indices[0], indices[-1]
          span = (token_timestamps[first_idx][0], token_timestamps[last_idx][1])
          timings.append({"text": word, "timestamp": span, **extra_fields})
      return timings
  def _combine_tokens_into_words(
      tokenizer,
      tokens: list[int],
      language: Optional[str] = None,
      prepend_punctuations: str = "\"'“¡¿([{-",
      append_punctuations: str = "\"'.。,,!!??::”)]}、",
  ):
      """
      Groups tokens by word.

      The language falls back to `tokenizer.language` and then to `"english"`. Languages that don't typically
      use spaces are split on unicode boundaries, all others on spaces; punctuation is then merged into the
      neighboring words.

      Returns:
          A tuple of three parallel lists: the words (strings), the `token_id` sequences making up each word,
          and the positions of those tokens within `tokens`.
      """
      resolved_language = language
      if resolved_language is None:
          resolved_language = tokenizer.language
          if resolved_language is None:
              resolved_language = "english"

      # These languages don't typically use spaces.
      if resolved_language in {"chinese", "japanese", "thai", "lao", "myanmar", "cantonese"}:
          splitter = _split_tokens_on_unicode
      else:
          splitter = _split_tokens_on_spaces

      words, word_tokens, token_indices = splitter(tokenizer, tokens)
      _merge_punctuations(words, word_tokens, token_indices, prepend_punctuations, append_punctuations)
      return words, word_tokens, token_indices
  1174. def _split_tokens_on_unicode(tokenizer, tokens: list[int]):
  1175. """Combine tokens into words by splitting at any position where the tokens are decoded as valid unicode points."""
  1176. decoded_full = tokenizer.decode(tokens, decode_with_timestamps=True)
  1177. replacement_char = "\ufffd"
  1178. words = []
  1179. word_tokens = []
  1180. token_indices = []
  1181. current_tokens = []
  1182. current_indices = []
  1183. unicode_offset = 0
  1184. for token_idx, token in enumerate(tokens):
  1185. current_tokens.append(token)
  1186. current_indices.append(token_idx)
  1187. decoded = tokenizer.decode(current_tokens, decode_with_timestamps=True)
  1188. if (
  1189. replacement_char not in decoded
  1190. or decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char
  1191. ):
  1192. words.append(decoded)
  1193. word_tokens.append(current_tokens)
  1194. token_indices.append(current_indices)
  1195. current_tokens = []
  1196. current_indices = []
  1197. unicode_offset += len(decoded)
  1198. return words, word_tokens, token_indices
  def _split_tokens_on_spaces(tokenizer, tokens: list[int]):
      """
      Combine tokens into words by splitting at whitespace and punctuation tokens.

      The unicode-level pieces are taken as a starting point. A piece opens a new word when it is the first
      one, starts with a space, consists of ASCII punctuation once stripped, or its first token id is at or
      above `tokenizer.eos_token_id`; any other piece is glued onto the current word.

      Returns:
          Three parallel lists: the words, the token ids of each word, and their positions in `tokens`.
      """
      pieces, piece_tokens_list, piece_indices_list = _split_tokens_on_unicode(tokenizer, tokens)
      ascii_punctuation = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

      words, word_tokens, token_indices = [], [], []
      for piece, piece_tokens, piece_indices in zip(pieces, piece_tokens_list, piece_indices_list):
          is_special = piece_tokens[0] >= tokenizer.eos_token_id
          starts_word = (
              is_special
              or piece.startswith(" ")
              or piece.strip() in ascii_punctuation
              or not words
          )
          if not starts_word:
              # Continuation of the current word.
              words[-1] = words[-1] + piece
              word_tokens[-1].extend(piece_tokens)
              token_indices[-1].extend(piece_indices)
              continue

          words.append(piece)
          word_tokens.append(piece_tokens)
          token_indices.append(piece_indices)

      return words, word_tokens, token_indices
  1218. def _merge_punctuations(words, tokens, indices, prepended, appended):
  1219. """Merges punctuation tokens with neighboring words."""
  1220. # prepend punctuations
  1221. i = len(words) - 2
  1222. j = len(words) - 1
  1223. while i >= 0:
  1224. if words[i].startswith(" ") and words[i].strip() in prepended:
  1225. words[j] = words[i] + words[j]
  1226. tokens[j] = tokens[i] + tokens[j]
  1227. indices[j] = indices[i] + indices[j]
  1228. words[i] = ""
  1229. tokens[i] = []
  1230. indices[i] = []
  1231. else:
  1232. j = i
  1233. i -= 1
  1234. # append punctuations
  1235. i = 0
  1236. j = 1
  1237. while j < len(words):
  1238. if not words[i].endswith(" ") and words[j] in appended:
  1239. words[i] += words[j]
  1240. tokens[i] += tokens[j]
  1241. indices[i] += indices[j]
  1242. words[j] = ""
  1243. tokens[j] = []
  1244. indices[j] = []
  1245. else:
  1246. i = j
  1247. j += 1
  1248. # remove elements that are now empty
  1249. words[:] = [word for word in words if word]
  1250. tokens[:] = [token for token in tokens if token]
  1251. indices[:] = [idx for idx in indices if idx]
  # Names exported by `from <this module> import *`.
  __all__ = ["WhisperTokenizer"]