tokenizer.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List
  15. from tokenizers import Tokenizer
  16. class JiebaBPETokenizer:
  17. """SentencePiece BPE tokenizer with Jieba integration"""
  18. def __init__(self, tokenizer_json_file):
  19. self.name = 'Jieba BPE Tokenizer'
  20. self.tokenizer = Tokenizer.from_file(tokenizer_json_file)
  21. self.eod_id = self.tokenizer.token_to_id('<|endoftext|>')
  22. try:
  23. import jieba
  24. import logging
  25. jieba.setLogLevel(logging.INFO)
  26. except ImportError:
  27. raise ImportError(
  28. 'You need to install jieba to use JiebaTokenizer. '
  29. 'See https://pypi.org/project/jieba/ for installation.')
  30. self.jieba = jieba
  31. self.new_line = self.vocab['\n']
  32. self.sep_token = self.vocab['<sep>']
  33. @property
  34. def vocab_size(self):
  35. return self.tokenizer.get_vocab_size(with_added_tokens=True)
  36. @property
  37. def vocab(self):
  38. return self.tokenizer.get_vocab(with_added_tokens=True)
  39. @property
  40. def inv_vocab(self):
  41. vocab = self.vocab
  42. inv_vocab = dict()
  43. for key, val in vocab.items():
  44. inv_vocab[val] = key
  45. return inv_vocab
  46. def tokenize(self, text: str, is_code: bool = False) -> List[int]:
  47. """
  48. """
  49. if not is_code:
  50. seg_list = [x for x in self.jieba.cut(text)]
  51. return self.tokenizer.encode(
  52. seg_list, is_pretokenized=True, add_special_tokens=True).ids
  53. else:
  54. return self.tokenizer.encode(
  55. text, is_pretokenized=False, add_special_tokens=True).ids
  56. def detokenize(self, token_ids: List[int], early_stop: bool = True) -> str:
  57. if early_stop and self.sep_token in token_ids:
  58. token_ids = token_ids[:token_ids.index(self.sep_token)]
  59. text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
  60. return text
  61. @property
  62. def eod(self):
  63. return self.eod_id