tokenization_fsmt.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. # coding=utf-8
  2. # Copyright 2019 The Open AI Team Authors and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Tokenization classes for FSMT."""
  16. import json
  17. import os
  18. import re
  19. import unicodedata
  20. from typing import Optional
  21. from ...tokenization_utils import PreTrainedTokenizer
  22. from ...utils import logging
# Module-level logger, obtained from the library's own logging utilities.
logger = logging.get_logger(__name__)

# Default file names of the three tokenizer resources: source-language vocabulary, target-language vocabulary and
# BPE merges. Exposed as `FSMTTokenizer.vocab_files_names` and used by `FSMTTokenizer.save_vocabulary` to name the
# files it writes.
VOCAB_FILES_NAMES = {
    "src_vocab_file": "vocab-src.json",
    "tgt_vocab_file": "vocab-tgt.json",
    "merges_file": "merges.txt",
}
  29. def get_pairs(word):
  30. """
  31. Return set of symbol pairs in a word. word is represented as tuple of symbols (symbols being variable-length
  32. strings)
  33. """
  34. pairs = set()
  35. prev_char = word[0]
  36. for char in word[1:]:
  37. pairs.add((prev_char, char))
  38. prev_char = char
  39. return pairs
  40. def replace_unicode_punct(text):
  41. """
  42. Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
  43. """
  44. text = text.replace(",", ",")
  45. text = re.sub(r"。\s*", ". ", text)
  46. text = text.replace("、", ",")
  47. text = text.replace("”", '"')
  48. text = text.replace("“", '"')
  49. text = text.replace("∶", ":")
  50. text = text.replace(":", ":")
  51. text = text.replace("?", "?")
  52. text = text.replace("《", '"')
  53. text = text.replace("》", '"')
  54. text = text.replace(")", ")")
  55. text = text.replace("!", "!")
  56. text = text.replace("(", "(")
  57. text = text.replace(";", ";")
  58. text = text.replace("1", "1")
  59. text = text.replace("」", '"')
  60. text = text.replace("「", '"')
  61. text = text.replace("0", "0")
  62. text = text.replace("3", "3")
  63. text = text.replace("2", "2")
  64. text = text.replace("5", "5")
  65. text = text.replace("6", "6")
  66. text = text.replace("9", "9")
  67. text = text.replace("7", "7")
  68. text = text.replace("8", "8")
  69. text = text.replace("4", "4")
  70. text = re.sub(r".\s*", ". ", text)
  71. text = text.replace("~", "~")
  72. text = text.replace("’", "'")
  73. text = text.replace("…", "...")
  74. text = text.replace("━", "-")
  75. text = text.replace("〈", "<")
  76. text = text.replace("〉", ">")
  77. text = text.replace("【", "[")
  78. text = text.replace("】", "]")
  79. text = text.replace("%", "%")
  80. return text
  81. def remove_non_printing_char(text):
  82. """
  83. Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
  84. """
  85. output = []
  86. for char in text:
  87. cat = unicodedata.category(char)
  88. if cat.startswith("C"):
  89. continue
  90. output.append(char)
  91. return "".join(output)
  92. # Porting notes:
  93. # this one is modeled after XLMTokenizer
  94. #
  95. # added:
  96. # - src_vocab_file,
  97. # - tgt_vocab_file,
  98. # - langs,
  99. class FSMTTokenizer(PreTrainedTokenizer):
  100. """
  101. Construct an FAIRSEQ Transformer tokenizer. Based on Byte-Pair Encoding. The tokenization process is the following:
  102. - Moses preprocessing and tokenization.
  103. - Normalizing all inputs text.
  104. - The arguments `special_tokens` and the function `set_special_tokens`, can be used to add additional symbols (like
  105. "__classify__") to a vocabulary.
  106. - The argument `langs` defines a pair of languages.
  107. This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
  108. this superclass for more information regarding those methods.
  109. Args:
  110. langs (`List[str]`, *optional*):
  111. A list of two languages to translate from and to, for instance `["en", "ru"]`.
  112. src_vocab_file (`str`, *optional*):
  113. File containing the vocabulary for the source language.
  114. tgt_vocab_file (`st`, *optional*):
  115. File containing the vocabulary for the target language.
  116. merges_file (`str`, *optional*):
  117. File containing the merges.
  118. do_lower_case (`bool`, *optional*, defaults to `False`):
  119. Whether or not to lowercase the input when tokenizing.
  120. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  121. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  122. token instead.
  123. bos_token (`str`, *optional*, defaults to `"<s>"`):
  124. The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
  125. <Tip>
  126. When building a sequence using special tokens, this is not the token that is used for the beginning of
  127. sequence. The token used is the `cls_token`.
  128. </Tip>
  129. sep_token (`str`, *optional*, defaults to `"</s>"`):
  130. The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  131. sequence classification or for a text and a question for question answering. It is also used as the last
  132. token of a sequence built with special tokens.
  133. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  134. The token used for padding, for example when batching sequences of different lengths.
  135. """
  136. vocab_files_names = VOCAB_FILES_NAMES
  137. model_input_names = ["input_ids", "attention_mask"]
  138. def __init__(
  139. self,
  140. langs=None,
  141. src_vocab_file=None,
  142. tgt_vocab_file=None,
  143. merges_file=None,
  144. do_lower_case=False,
  145. unk_token="<unk>",
  146. bos_token="<s>",
  147. sep_token="</s>",
  148. pad_token="<pad>",
  149. **kwargs,
  150. ):
  151. try:
  152. import sacremoses
  153. except ImportError:
  154. raise ImportError(
  155. "You need to install sacremoses to use XLMTokenizer. "
  156. "See https://pypi.org/project/sacremoses/ for installation."
  157. )
  158. self.sm = sacremoses
  159. self.src_vocab_file = src_vocab_file
  160. self.tgt_vocab_file = tgt_vocab_file
  161. self.merges_file = merges_file
  162. self.do_lower_case = do_lower_case
  163. # cache of sm.MosesPunctNormalizer instance
  164. self.cache_moses_punct_normalizer = {}
  165. # cache of sm.MosesTokenizer instance
  166. self.cache_moses_tokenizer = {}
  167. self.cache_moses_detokenizer = {}
  168. if langs and len(langs) == 2:
  169. self.src_lang, self.tgt_lang = langs
  170. else:
  171. raise ValueError(
  172. f"arg `langs` needs to be a list of 2 langs, e.g. ['en', 'ru'], but got {langs}. "
  173. "Usually that means that tokenizer can't find a mapping for the given model path "
  174. "in and other maps of this tokenizer."
  175. )
  176. with open(src_vocab_file, encoding="utf-8") as src_vocab_handle:
  177. self.encoder = json.load(src_vocab_handle)
  178. with open(tgt_vocab_file, encoding="utf-8") as tgt_vocab_handle:
  179. tgt_vocab = json.load(tgt_vocab_handle)
  180. self.decoder = {v: k for k, v in tgt_vocab.items()}
  181. with open(merges_file, encoding="utf-8") as merges_handle:
  182. merges = merges_handle.read().split("\n")[:-1]
  183. merges = [tuple(merge.split()[:2]) for merge in merges]
  184. self.bpe_ranks = dict(zip(merges, range(len(merges))))
  185. self.cache = {}
  186. super().__init__(
  187. langs=langs,
  188. src_vocab_file=src_vocab_file,
  189. tgt_vocab_file=tgt_vocab_file,
  190. merges_file=merges_file,
  191. do_lower_case=do_lower_case,
  192. unk_token=unk_token,
  193. bos_token=bos_token,
  194. sep_token=sep_token,
  195. pad_token=pad_token,
  196. **kwargs,
  197. )
  198. # hack override
  199. def get_vocab(self) -> dict[str, int]:
  200. return self.get_src_vocab()
  201. # hack override
  202. @property
  203. def vocab_size(self) -> int:
  204. return self.src_vocab_size
  205. def moses_punct_norm(self, text, lang):
  206. if lang not in self.cache_moses_punct_normalizer:
  207. punct_normalizer = self.sm.MosesPunctNormalizer(lang=lang)
  208. self.cache_moses_punct_normalizer[lang] = punct_normalizer
  209. return self.cache_moses_punct_normalizer[lang].normalize(text)
  210. def moses_tokenize(self, text, lang):
  211. if lang not in self.cache_moses_tokenizer:
  212. moses_tokenizer = self.sm.MosesTokenizer(lang=lang)
  213. self.cache_moses_tokenizer[lang] = moses_tokenizer
  214. return self.cache_moses_tokenizer[lang].tokenize(
  215. text, aggressive_dash_splits=True, return_str=False, escape=True
  216. )
  217. def moses_detokenize(self, tokens, lang):
  218. if lang not in self.cache_moses_detokenizer:
  219. moses_detokenizer = self.sm.MosesDetokenizer(lang=lang)
  220. self.cache_moses_detokenizer[lang] = moses_detokenizer
  221. return self.cache_moses_detokenizer[lang].detokenize(tokens)
  222. def moses_pipeline(self, text, lang):
  223. text = replace_unicode_punct(text)
  224. text = self.moses_punct_norm(text, lang)
  225. text = remove_non_printing_char(text)
  226. return text
  227. @property
  228. def src_vocab_size(self):
  229. return len(self.encoder)
  230. @property
  231. def tgt_vocab_size(self):
  232. return len(self.decoder)
  233. def get_src_vocab(self):
  234. return dict(self.encoder, **self.added_tokens_encoder)
  235. def get_tgt_vocab(self):
  236. return dict(self.decoder, **self.added_tokens_decoder)
  237. def bpe(self, token):
  238. word = tuple(token[:-1]) + (token[-1] + "</w>",)
  239. if token in self.cache:
  240. return self.cache[token]
  241. pairs = get_pairs(word)
  242. if not pairs:
  243. return token + "</w>"
  244. while True:
  245. bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
  246. if bigram not in self.bpe_ranks:
  247. break
  248. first, second = bigram
  249. new_word = []
  250. i = 0
  251. while i < len(word):
  252. try:
  253. j = word.index(first, i)
  254. except ValueError:
  255. new_word.extend(word[i:])
  256. break
  257. else:
  258. new_word.extend(word[i:j])
  259. i = j
  260. if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
  261. new_word.append(first + second)
  262. i += 2
  263. else:
  264. new_word.append(word[i])
  265. i += 1
  266. new_word = tuple(new_word)
  267. word = new_word
  268. if len(word) == 1:
  269. break
  270. else:
  271. pairs = get_pairs(word)
  272. word = " ".join(word)
  273. if word == "\n </w>":
  274. word = "\n</w>"
  275. self.cache[token] = word
  276. return word
  277. def _tokenize(self, text, lang="en", bypass_tokenizer=False):
  278. """
  279. Tokenize a string given language code using Moses.
  280. Details of tokenization:
  281. - [sacremoses](https://github.com/alvations/sacremoses): port of Moses
  282. - Install with `pip install sacremoses`
  283. Args:
  284. - lang: ISO language code (default = 'en') (string). Languages should belong of the model supported
  285. languages. However, we don't enforce it.
  286. - bypass_tokenizer: Allow users to preprocess and tokenize the sentences externally (default = False)
  287. (bool). If True, we only apply BPE.
  288. Returns:
  289. List of tokens.
  290. """
  291. # ignore `lang` which is currently isn't explicitly passed in tokenization_utils.py and always results in lang=en
  292. # if lang != self.src_lang:
  293. # raise ValueError(f"Expected lang={self.src_lang}, but got {lang}")
  294. lang = self.src_lang
  295. if self.do_lower_case:
  296. text = text.lower()
  297. if bypass_tokenizer:
  298. text = text.split()
  299. else:
  300. text = self.moses_pipeline(text, lang=lang)
  301. text = self.moses_tokenize(text, lang=lang)
  302. split_tokens = []
  303. for token in text:
  304. if token:
  305. split_tokens.extend(list(self.bpe(token).split(" ")))
  306. return split_tokens
  307. def _convert_token_to_id(self, token):
  308. """Converts a token (str) in an id using the vocab."""
  309. return self.encoder.get(token, self.encoder.get(self.unk_token))
  310. def _convert_id_to_token(self, index):
  311. """Converts an index (integer) in a token (str) using the vocab."""
  312. return self.decoder.get(index, self.unk_token)
  313. def convert_tokens_to_string(self, tokens):
  314. """Converts a sequence of tokens (string) in a single string."""
  315. # remove BPE
  316. tokens = [t.replace(" ", "").replace("</w>", " ") for t in tokens]
  317. tokens = "".join(tokens).split()
  318. # detokenize
  319. text = self.moses_detokenize(tokens, self.tgt_lang)
  320. return text
  321. def build_inputs_with_special_tokens(
  322. self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
  323. ) -> list[int]:
  324. """
  325. Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
  326. adding special tokens. A FAIRSEQ Transformer sequence has the following format:
  327. - single sequence: `<s> X </s>`
  328. - pair of sequences: `<s> A </s> B </s>`
  329. Args:
  330. token_ids_0 (`List[int]`):
  331. List of IDs to which the special tokens will be added.
  332. token_ids_1 (`List[int]`, *optional*):
  333. Optional second list of IDs for sequence pairs.
  334. Returns:
  335. `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
  336. """
  337. sep = [self.sep_token_id]
  338. # no bos used in fairseq
  339. if token_ids_1 is None:
  340. return token_ids_0 + sep
  341. return token_ids_0 + sep + token_ids_1 + sep
  342. def get_special_tokens_mask(
  343. self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
  344. ) -> list[int]:
  345. """
  346. Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
  347. special tokens using the tokenizer `prepare_for_model` method.
  348. Args:
  349. token_ids_0 (`List[int]`):
  350. List of IDs.
  351. token_ids_1 (`List[int]`, *optional*):
  352. Optional second list of IDs for sequence pairs.
  353. already_has_special_tokens (`bool`, *optional*, defaults to `False`):
  354. Whether or not the token list is already formatted with special tokens for the model.
  355. Returns:
  356. `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
  357. """
  358. if already_has_special_tokens:
  359. return super().get_special_tokens_mask(
  360. token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
  361. )
  362. # no bos used in fairseq
  363. if token_ids_1 is not None:
  364. return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
  365. return ([0] * len(token_ids_0)) + [1]
  366. def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
  367. if not os.path.isdir(save_directory):
  368. logger.error(f"Vocabulary path ({save_directory}) should be a directory")
  369. return
  370. src_vocab_file = os.path.join(
  371. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["src_vocab_file"]
  372. )
  373. tgt_vocab_file = os.path.join(
  374. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["tgt_vocab_file"]
  375. )
  376. merges_file = os.path.join(
  377. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
  378. )
  379. with open(src_vocab_file, "w", encoding="utf-8") as f:
  380. f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
  381. with open(tgt_vocab_file, "w", encoding="utf-8") as f:
  382. tgt_vocab = {v: k for k, v in self.decoder.items()}
  383. f.write(json.dumps(tgt_vocab, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
  384. index = 0
  385. with open(merges_file, "w", encoding="utf-8") as writer:
  386. for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
  387. if index != token_index:
  388. logger.warning(
  389. f"Saving vocabulary to {merges_file}: BPE merge indices are not consecutive."
  390. " Please check that the tokenizer is not corrupted!"
  391. )
  392. index = token_index
  393. writer.write(" ".join(bpe_tokens) + "\n")
  394. index += 1
  395. return src_vocab_file, tgt_vocab_file, merges_file
  396. def __getstate__(self):
  397. state = self.__dict__.copy()
  398. state["sm"] = None
  399. return state
  400. def __setstate__(self, d):
  401. self.__dict__ = d
  402. try:
  403. import sacremoses
  404. except ImportError:
  405. raise ImportError(
  406. "You need to install sacremoses to use XLMTokenizer. "
  407. "See https://pypi.org/project/sacremoses/ for installation."
  408. )
  409. self.sm = sacremoses
# Names exported by `from <this module> import *`.
__all__ = ["FSMTTokenizer"]