tokenizer.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from __future__ import (absolute_import, division, print_function,
  3. unicode_literals)
  4. import collections
  5. import logging
  6. import os
  7. import sys
  8. import unicodedata
  9. import json
  10. import regex as re
  11. def clean_string(string):
  12. replace_mp = {
  13. ' - ': '-',
  14. " ' ": "'",
  15. " n't": "n't",
  16. " 'm": "'m",
  17. ' do not': " don't",
  18. " 's": "'s",
  19. " 've": "'ve",
  20. " 're": "'re"
  21. }
  22. for k, v in replace_mp.items():
  23. string = string.replace(k, v)
  24. return string
  25. class Tokenizer(object):
  26. def __init__(self, vocab_path, special_tokens=[], tokenizer_type='Bert'):
  27. self.tokenizer_type = tokenizer_type
  28. if tokenizer_type == 'Bert':
  29. self.spec_convert_dict = {
  30. '[BOS]': '[unused0]',
  31. '[EOS]': '[unused1]'
  32. }
  33. for token in special_tokens:
  34. if token not in self.spec_convert_dict and token not in [
  35. '[PAD]', '[UNK]'
  36. ]:
  37. self.spec_convert_dict[
  38. token] = f'[unused{len(self.spec_convert_dict)}]'
  39. self.spec_revert_dict = {
  40. v: k
  41. for k, v in self.spec_convert_dict.items()
  42. }
  43. special_tokens = [
  44. self.spec_convert_dict.get(tok, tok) for tok in special_tokens
  45. ]
  46. self.special_tokens = ('[UNK]', '[SEP]', '[PAD]', '[CLS]',
  47. '[MASK]')
  48. self.special_tokens += tuple(x for x in special_tokens
  49. if x not in self.special_tokens)
  50. self._tokenizer = BertTokenizer(
  51. vocab_path, never_split=self.special_tokens)
  52. for tok in self.special_tokens:
  53. assert tok in self._tokenizer.vocab, f"special token '{tok}' is not in the vocabulary"
  54. self.vocab_size = len(self._tokenizer.vocab)
  55. elif tokenizer_type == 'GPT2':
  56. self.spec_convert_dict = {'[UNK]': '<unk>'}
  57. self.spec_revert_dict = {
  58. v: k
  59. for k, v in self.spec_convert_dict.items()
  60. }
  61. special_tokens = [
  62. tok for tok in special_tokens
  63. if tok not in self.spec_convert_dict
  64. ]
  65. vocab_file = os.path.join(vocab_path, 'vocab.json')
  66. merges_file = os.path.join(vocab_path, 'merges.txt')
  67. self._tokenizer = GPT2Tokenizer(
  68. vocab_file, merges_file, special_tokens=special_tokens)
  69. self.num_specials = len(special_tokens)
  70. self.vocab_size = len(self._tokenizer)
  71. else:
  72. raise ValueError
  73. def tokenize(self, text):
  74. return self._tokenizer.tokenize(text)
  75. def convert_tokens_to_ids(self, tokens):
  76. if self.tokenizer_type == 'Bert':
  77. tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens]
  78. ids = self._tokenizer.convert_tokens_to_ids(tokens)
  79. return ids
  80. else:
  81. tokens = [self.spec_convert_dict.get(tok, tok) for tok in tokens]
  82. ids = self._tokenizer.convert_tokens_to_ids(tokens)
  83. ids = [(i + self.num_specials) % self.vocab_size for i in ids]
  84. return ids
  85. def convert_ids_to_tokens(self, ids):
  86. if self.tokenizer_type == 'Bert':
  87. tokens = self._tokenizer.convert_ids_to_tokens(ids)
  88. tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens]
  89. return tokens
  90. else:
  91. ids = [(i - self.num_specials) % self.vocab_size for i in ids]
  92. tokens = self._tokenizer.convert_ids_to_tokens(ids)
  93. tokens = [self.spec_revert_dict.get(tok, tok) for tok in tokens]
  94. return tokens
  95. def decode(self, ids, ignore_tokens=[]):
  96. tokens = self.convert_ids_to_tokens(ids)
  97. if len(ignore_tokens) > 0:
  98. ignore_tokens = set(ignore_tokens)
  99. tokens = [tok for tok in tokens if tok not in ignore_tokens]
  100. if self.tokenizer_type == 'Bert':
  101. string = ' '.join(tokens).replace(' ##', '')
  102. else:
  103. string = ''.join(tokens)
  104. string = bytearray([
  105. self._tokenizer.byte_decoder[c] for c in string
  106. ]).decode('utf-8')
  107. string = clean_string(string)
  108. return string
  109. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  110. #
  111. # Licensed under the Apache License, Version 2.0 (the "License");
  112. # you may not use this file except in compliance with the License.
  113. # You may obtain a copy of the License at
  114. #
  115. # http://www.apache.org/licenses/LICENSE-2.0
  116. #
  117. # Unless required by applicable law or agreed to in writing, software
  118. # distributed under the License is distributed on an "AS IS" BASIS,
  119. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  120. # See the License for the specific language governing permissions and
  121. # limitations under the License.
  122. """Tokenization classes."""
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  124. def load_vocab(vocab_file):
  125. """Loads a vocabulary file into a dictionary."""
  126. vocab = collections.OrderedDict()
  127. index = 0
  128. with open(vocab_file, 'r', encoding='utf-8') as reader:
  129. while True:
  130. token = reader.readline()
  131. if not token:
  132. break
  133. token = token.strip()
  134. vocab[token] = index
  135. index += 1
  136. return vocab
  137. def whitespace_tokenize(text):
  138. """Runs basic whitespace cleaning and splitting on a piece of text."""
  139. text = text.strip()
  140. if not text:
  141. return []
  142. tokens = text.split()
  143. return tokens
  144. class BertTokenizer(object):
  145. """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
  146. def __init__(self,
  147. vocab_file,
  148. do_lower_case=True,
  149. max_len=None,
  150. do_basic_tokenize=True,
  151. never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')):
  152. """Constructs a BertTokenizer.
  153. Args:
  154. vocab_file: Path to a one-wordpiece-per-line vocabulary file
  155. do_lower_case: Whether to lower case the input
  156. Only has an effect when do_wordpiece_only=False
  157. do_basic_tokenize: Whether to do basic tokenization before wordpiece.
  158. max_len: An artificial maximum length to truncate tokenized sequences to;
  159. Effective maximum length is always the minimum of this
  160. value (if specified) and the underlying BERT model's
  161. sequence length.
  162. never_split: List of tokens which will never be split during tokenization.
  163. Only has an effect when do_wordpiece_only=False
  164. """
  165. if not os.path.isfile(vocab_file):
  166. raise ValueError(
  167. "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
  168. 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`'
  169. .format(vocab_file))
  170. self.vocab = load_vocab(vocab_file)
  171. self.ids_to_tokens = collections.OrderedDict([
  172. (ids, tok) for tok, ids in self.vocab.items()
  173. ])
  174. self.do_basic_tokenize = do_basic_tokenize
  175. if do_basic_tokenize:
  176. self.basic_tokenizer = BasicTokenizer(
  177. do_lower_case=do_lower_case, never_split=never_split)
  178. self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
  179. self.max_len = max_len if max_len is not None else int(1e12)
  180. def tokenize(self, text):
  181. split_tokens = []
  182. if self.do_basic_tokenize:
  183. for token in self.basic_tokenizer.tokenize(text):
  184. for sub_token in self.wordpiece_tokenizer.tokenize(token):
  185. split_tokens.append(sub_token)
  186. else:
  187. split_tokens = self.wordpiece_tokenizer.tokenize(text)
  188. return split_tokens
  189. def convert_tokens_to_ids(self, tokens):
  190. """Converts a sequence of tokens into ids using the vocab."""
  191. ids = []
  192. for token in tokens:
  193. ids.append(self.vocab[token])
  194. if len(ids) > self.max_len:
  195. logger.warning(
  196. 'Token indices sequence length is longer than the specified maximum '
  197. ' sequence length for this BERT model ({} > {}). Running this'
  198. ' sequence through BERT will result in indexing errors'.format(
  199. len(ids), self.max_len))
  200. return ids
  201. def convert_ids_to_tokens(self, ids):
  202. """Converts a sequence of ids in wordpiece tokens using the vocab."""
  203. tokens = []
  204. for i in ids:
  205. tokens.append(self.ids_to_tokens[i])
  206. return tokens
  207. class BasicTokenizer(object):
  208. """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
  209. def __init__(self,
  210. do_lower_case=True,
  211. never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')):
  212. """Constructs a BasicTokenizer.
  213. Args:
  214. do_lower_case: Whether to lower case the input.
  215. """
  216. self.do_lower_case = do_lower_case
  217. self.never_split = never_split
  218. def tokenize(self, text):
  219. """Tokenizes a piece of text."""
  220. text = self._clean_text(text)
  221. # This was added on November 1st, 2018 for the multilingual and Chinese
  222. # models. This is also applied to the English models now, but it doesn't
  223. # matter since the English models were not trained on any Chinese data
  224. # and generally don't have any Chinese data in them (there are Chinese
  225. # characters in the vocabulary because Wikipedia does have some Chinese
  226. # words in the English Wikipedia.).
  227. text = self._tokenize_chinese_chars(text)
  228. orig_tokens = whitespace_tokenize(text)
  229. split_tokens = []
  230. for token in orig_tokens:
  231. if self.do_lower_case and token not in self.never_split:
  232. token = token.lower()
  233. token = self._run_strip_accents(token)
  234. split_tokens.extend(self._run_split_on_punc(token))
  235. output_tokens = whitespace_tokenize(' '.join(split_tokens))
  236. return output_tokens
  237. def _run_strip_accents(self, text):
  238. """Strips accents from a piece of text."""
  239. text = unicodedata.normalize('NFD', text)
  240. output = []
  241. for char in text:
  242. cat = unicodedata.category(char)
  243. if cat == 'Mn':
  244. continue
  245. output.append(char)
  246. return ''.join(output)
  247. def _run_split_on_punc(self, text):
  248. """Splits punctuation on a piece of text."""
  249. if text in self.never_split:
  250. return [text]
  251. chars = list(text)
  252. i = 0
  253. start_new_word = True
  254. output = []
  255. while i < len(chars):
  256. char = chars[i]
  257. if _is_punctuation(char):
  258. output.append([char])
  259. start_new_word = True
  260. else:
  261. if start_new_word:
  262. output.append([])
  263. start_new_word = False
  264. output[-1].append(char)
  265. i += 1
  266. return [''.join(x) for x in output]
  267. def _tokenize_chinese_chars(self, text):
  268. """Adds whitespace around any CJK character."""
  269. output = []
  270. for char in text:
  271. cp = ord(char)
  272. if self._is_chinese_char(cp):
  273. output.append(' ')
  274. output.append(char)
  275. output.append(' ')
  276. else:
  277. output.append(char)
  278. return ''.join(output)
  279. def _is_chinese_char(self, cp):
  280. """Checks whether CP is the codepoint of a CJK character."""
  281. # This defines a "chinese character" as anything in the CJK Unicode block:
  282. # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
  283. #
  284. # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
  285. # despite its name. The modern Korean Hangul alphabet is a different block,
  286. # as is Japanese Hiragana and Katakana. Those alphabets are used to write
  287. # space-separated words, so they are not treated specially and handled
  288. # like the all of the other languages.
  289. tmp = (cp >= 0x4E00 and cp <= 0x9FFF)
  290. tmp = tmp or (cp >= 0x3400 and cp <= 0x4DBF)
  291. tmp = tmp or (cp >= 0x20000 and cp <= 0x2A6DF)
  292. tmp = tmp or (cp >= 0x2A700 and cp <= 0x2B73F)
  293. tmp = tmp or (cp >= 0x2B740 and cp <= 0x2B81F)
  294. tmp = tmp or (cp >= 0x2B820 and cp <= 0x2CEAF)
  295. tmp = tmp or (cp >= 0xF900 and cp <= 0xFAFF)
  296. tmp = tmp or (cp >= 0x2F800 and cp <= 0x2FA1F)
  297. if tmp:
  298. return True
  299. return False
  300. def _clean_text(self, text):
  301. """Performs invalid character removal and whitespace cleanup on text."""
  302. output = []
  303. for char in text:
  304. cp = ord(char)
  305. if cp == 0 or cp == 0xfffd or _is_control(char):
  306. continue
  307. if _is_whitespace(char):
  308. output.append(' ')
  309. else:
  310. output.append(char)
  311. return ''.join(output)
  312. class WordpieceTokenizer(object):
  313. """Runs WordPiece tokenization."""
  314. def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100):
  315. self.vocab = vocab
  316. self.unk_token = unk_token
  317. self.max_input_chars_per_word = max_input_chars_per_word
  318. def tokenize(self, text):
  319. """Tokenizes a piece of text into its word pieces.
  320. This uses a greedy longest-match-first algorithm to perform tokenization
  321. using the given vocabulary.
  322. For example:
  323. >>> input = "unaffable"
  324. >>> output = ["un", "##aff", "##able"]
  325. Args:
  326. text: A single token or whitespace separated tokens. This should have
  327. already been passed through `BasicTokenizer`.
  328. Returns:
  329. A list of wordpiece tokens.
  330. """
  331. output_tokens = []
  332. for token in whitespace_tokenize(text):
  333. chars = list(token)
  334. if len(chars) > self.max_input_chars_per_word:
  335. output_tokens.append(self.unk_token)
  336. continue
  337. is_bad = False
  338. start = 0
  339. sub_tokens = []
  340. while start < len(chars):
  341. end = len(chars)
  342. cur_substr = None
  343. while start < end:
  344. substr = ''.join(chars[start:end])
  345. if start > 0:
  346. substr = '##' + substr
  347. if substr in self.vocab:
  348. cur_substr = substr
  349. break
  350. end -= 1
  351. if cur_substr is None:
  352. is_bad = True
  353. break
  354. sub_tokens.append(cur_substr)
  355. start = end
  356. if is_bad:
  357. output_tokens.append(self.unk_token)
  358. else:
  359. output_tokens.extend(sub_tokens)
  360. return output_tokens
  361. def _is_whitespace(char):
  362. """Checks whether `chars` is a whitespace character."""
  363. # \t, \n, and \r are technically control characters but we treat them
  364. # as whitespace since they are generally considered as such.
  365. if char == ' ' or char == '\t' or char == '\n' or char == '\r':
  366. return True
  367. cat = unicodedata.category(char)
  368. if cat == 'Zs':
  369. return True
  370. return False
  371. def _is_control(char):
  372. """Checks whether `chars` is a control character."""
  373. # These are technically control characters but we count them as whitespace
  374. # characters.
  375. if char == '\t' or char == '\n' or char == '\r':
  376. return False
  377. cat = unicodedata.category(char)
  378. if cat.startswith('C'):
  379. return True
  380. return False
  381. def _is_punctuation(char):
  382. """Checks whether `chars` is a punctuation character."""
  383. cp = ord(char)
  384. # We treat all non-letter/number ASCII as punctuation.
  385. # Characters such as "^", "$", and "`" are not in the Unicode
  386. # Punctuation class but we treat them as punctuation anyways, for
  387. # consistency.
  388. tmp = (cp >= 33 and cp <= 47)
  389. tmp = tmp or (cp >= 58 and cp <= 64)
  390. tmp = tmp or (cp >= 91 and cp <= 96)
  391. tmp = tmp or (cp >= 123 and cp <= 126)
  392. if tmp:
  393. return True
  394. cat = unicodedata.category(char)
  395. if cat.startswith('P'):
  396. return True
  397. return False
  398. # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
  399. #
  400. # Licensed under the Apache License, Version 2.0 (the "License");
  401. # you may not use this file except in compliance with the License.
  402. # You may obtain a copy of the License at
  403. #
  404. # http://www.apache.org/licenses/LICENSE-2.0
  405. #
  406. # Unless required by applicable law or agreed to in writing, software
  407. # distributed under the License is distributed on an "AS IS" BASIS,
  408. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  409. # See the License for the specific language governing permissions and
  410. # limitations under the License.
  411. """Tokenization classes for OpenAI GPT."""
try:
    from functools import lru_cache
except ImportError:
    # functools.lru_cache could not be imported: fall back to a no-op
    # decorator factory so that `@lru_cache()` can still be applied.
    # Original author's note: just a dummy decorator to get the checks to run
    # on python2, because a byte-level unicode BPE tokenizer is not meant to
    # be supported on python 2 right now.
    def lru_cache():
        """Return a decorator that hands the function back unchanged."""
        return lambda func: func
  419. @lru_cache()
  420. def bytes_to_unicode():
  421. """
  422. Returns list of utf-8 byte and a corresponding list of unicode strings.
  423. The reversible bpe codes work on unicode strings.
  424. This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
  425. When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
  426. This is a significant percentage of your normal, say, 32K bpe vocab.
  427. To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
  428. And avoids mapping to whitespace/control characters the bpe code barfs on.
  429. """
  430. _chr = unichr if sys.version_info[0] == 2 else chr
  431. bs = list(range(ord('!'),
  432. ord('~') + 1)) + list(range(
  433. ord('¡'),
  434. ord('¬') + 1)) + list(range(ord('®'),
  435. ord('ÿ') + 1))
  436. cs = bs[:]
  437. n = 0
  438. for b in range(2**8):
  439. if b not in bs:
  440. bs.append(b)
  441. cs.append(2**8 + n)
  442. n += 1
  443. cs = [_chr(n) for n in cs]
  444. return dict(zip(bs, cs))
  445. def get_pairs(word):
  446. """Return set of symbol pairs in a word.
  447. Word is represented as tuple of symbols (symbols being variable-length strings).
  448. """
  449. pairs = set()
  450. prev_char = word[0]
  451. for char in word[1:]:
  452. pairs.add((prev_char, char))
  453. prev_char = char
  454. return pairs
  455. class GPT2Tokenizer(object):
  456. """
  457. GPT-2 BPE tokenizer. Peculiarities:
  458. - Byte-level BPE
  459. """
  460. def __init__(self,
  461. vocab_file,
  462. merges_file,
  463. errors='replace',
  464. special_tokens=None,
  465. max_len=None):
  466. self.max_len = max_len if max_len is not None else int(1e12)
  467. self.encoder = json.load(open(vocab_file, encoding='utf-8'))
  468. self.decoder = {v: k for k, v in self.encoder.items()}
  469. self.errors = errors # how to handle errors in decoding
  470. self.byte_encoder = bytes_to_unicode()
  471. self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
  472. bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
  473. bpe_merges = [tuple(merge.split()) for merge in bpe_data]
  474. self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
  475. self.cache = {}
  476. # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
  477. self.pat = re.compile(
  478. r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
  479. )
  480. self.special_tokens = {}
  481. self.special_tokens_decoder = {}
  482. self.set_special_tokens(special_tokens)
  483. def __len__(self):
  484. return len(self.encoder) + len(self.special_tokens)
  485. def set_special_tokens(self, special_tokens):
  486. """ Add a list of additional tokens to the encoder.
  487. The additional tokens are indexed starting from the last index of the
  488. current vocabulary in the order of the `special_tokens` list.
  489. """
  490. if not special_tokens:
  491. self.special_tokens = {}
  492. self.special_tokens_decoder = {}
  493. return
  494. self.special_tokens = dict((tok, len(self.encoder) + i)
  495. for i, tok in enumerate(special_tokens))
  496. self.special_tokens_decoder = {
  497. v: k
  498. for k, v in self.special_tokens.items()
  499. }
  500. logger.info('Special tokens {}'.format(self.special_tokens))
  501. def bpe(self, token):
  502. if token in self.cache:
  503. return self.cache[token]
  504. word = tuple(token)
  505. pairs = get_pairs(word)
  506. if not pairs:
  507. return token
  508. while True:
  509. bigram = min(
  510. pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
  511. if bigram not in self.bpe_ranks:
  512. break
  513. first, second = bigram
  514. new_word = []
  515. i = 0
  516. while i < len(word):
  517. try:
  518. j = word.index(first, i)
  519. new_word.extend(word[i:j])
  520. i = j
  521. except Exception:
  522. new_word.extend(word[i:])
  523. break
  524. if word[i] == first and i < len(word) - 1 and word[
  525. i + 1] == second:
  526. new_word.append(first + second)
  527. i += 2
  528. else:
  529. new_word.append(word[i])
  530. i += 1
  531. new_word = tuple(new_word)
  532. word = new_word
  533. if len(word) == 1:
  534. break
  535. else:
  536. pairs = get_pairs(word)
  537. word = ' '.join(word)
  538. self.cache[token] = word
  539. return word
  540. def tokenize(self, text):
  541. """ Tokenize a string. """
  542. bpe_tokens = []
  543. for token in re.findall(self.pat, text):
  544. token = ''.join(self.byte_encoder[ord(b)] for b in token
  545. if ord(b) in self.byte_encoder)
  546. if token == '':
  547. continue
  548. bpe_tokens.extend(
  549. bpe_token for bpe_token in self.bpe(token).split(' '))
  550. return bpe_tokens
  551. def convert_tokens_to_ids(self, tokens):
  552. """ Converts a sequence of tokens into ids using the vocab. """
  553. ids = []
  554. python_version_3 = isinstance(tokens, str)
  555. python_version_2 = (
  556. sys.version_info[0] == 2 and isinstance(tokens, unicode))
  557. if python_version_3 or python_version_2:
  558. if tokens in self.special_tokens:
  559. return self.special_tokens[tokens]
  560. else:
  561. return self.encoder.get(tokens, 0)
  562. for token in tokens:
  563. if token in self.special_tokens:
  564. ids.append(self.special_tokens[token])
  565. else:
  566. ids.append(self.encoder.get(token, 0))
  567. if len(ids) > self.max_len:
  568. logger.warning(
  569. 'Token indices sequence length is longer than the specified maximum '
  570. ' sequence length for this OpenAI GPT model ({} > {}). Running this'
  571. ' sequence through the model will result in indexing errors'.
  572. format(len(ids), self.max_len))
  573. return ids
  574. def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
  575. """Converts a sequence of ids in BPE tokens using the vocab."""
  576. tokens = []
  577. for i in ids:
  578. if i in self.special_tokens_decoder:
  579. if not skip_special_tokens:
  580. tokens.append(self.special_tokens_decoder[i])
  581. else:
  582. tokens.append(self.decoder[i])
  583. return tokens
  584. def encode(self, text):
  585. return self.convert_tokens_to_ids(self.tokenize(text))
  586. def decode(self, tokens):
  587. text = ''.join([self.decoder[token] for token in tokens])
  588. text = bytearray([self.byte_decoder[c] for c in text]).decode(
  589. 'utf-8', errors=self.errors)
  590. return text