tokenizer.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # Copyright (c) 2022 Zhipu.AI
  2. from typing import List, Union
  3. import torch
  4. from transformers import AutoTokenizer
  5. from transformers.models.gpt2 import GPT2TokenizerFast
  6. def encode_whitespaces(text, start_extra_id: int, max_len: int):
  7. """ Encode whitespaces to extra tokens in GPT-J.
  8. >>> encode_whitespaces('a\\n b\\n c', 10, 10)
  9. 'a\\n<|extratoken_10|>b\\n<|extratoken_11|>c'
  10. """
  11. def push_acc_space(acc_len: int, text: str):
  12. if acc_len == 0:
  13. return text
  14. if acc_len == 1:
  15. return text + ' '
  16. assert acc_len <= max_len, f'Max whitespace run length {max_len}, but found {acc_len}'
  17. extra_id = start_extra_id - 2 + acc_len
  18. extra_token = f'<|extratoken_{extra_id}|>'
  19. return text + extra_token
  20. acc_len = 0
  21. res = ''
  22. for ch in text:
  23. if ch == ' ':
  24. acc_len += 1
  25. if acc_len == max_len:
  26. res = push_acc_space(acc_len, res)
  27. acc_len = 0
  28. else:
  29. res = push_acc_space(acc_len, res)
  30. acc_len = 0
  31. res = res + ch
  32. res = push_acc_space(acc_len, res)
  33. return res
  34. def decode_whitespaces(text: str, start_extra_id: int, max_len: int):
  35. """ Decode the whitespace-encoded strings produced by encode_whitespace.
  36. >>> text = 'a\\n b\\n c'
  37. >>> s, l = 10, 10
  38. >>> text == decode_whitespaces(encode_whitespaces(text, s, l), s, l)
  39. True
  40. """
  41. for l in range(2, max_len + 1): # noqa
  42. token_id = start_extra_id - 2 + l
  43. token = f'<|extratoken_{token_id}|>'
  44. text = text.replace(token, ' ' * l)
  45. return text
  46. class Code13BDictionary(object):
  47. def __init__(
  48. self,
  49. dict_file: str,
  50. extra_token_ids: List[str] = None,
  51. pad_to_vocab_size: int = -1,
  52. ):
  53. self._idx = dict()
  54. self._count = dict()
  55. self._num_symbols = 0
  56. self._symbols = []
  57. self._add_symbol('<s>', 0)
  58. self._add_symbol('<pad>', 0)
  59. self._add_symbol('</s>', 0)
  60. self._add_symbol('<unk>', 0)
  61. self._load_dict(dict_file)
  62. if extra_token_ids is None:
  63. extra_token_ids = [str(x) for x in range(50257, 50400)
  64. ] # follows GPT-J settings
  65. for token_id in extra_token_ids:
  66. self._add_symbol(token_id, 0)
  67. if pad_to_vocab_size > 0:
  68. self._pad_to_vocab_size(pad_to_vocab_size)
  69. def _pad_to_vocab_size(self, vocab_size: int):
  70. num_pad = vocab_size - len(self)
  71. if num_pad <= 0:
  72. return
  73. for i in range(1, num_pad + 1):
  74. self._add_symbol('vocab_pad_token{}'.format(i), 0)
  75. def _load_dict(self, dict_file: str):
  76. with open(dict_file, 'r') as f:
  77. for line in f:
  78. line = line.strip()
  79. if line == '' or line.startswith('#'):
  80. continue
  81. sym, count = line.split()
  82. self._add_symbol(sym, int(count))
  83. def _add_symbol(self, sym: str, count: int):
  84. self._idx[sym] = self._num_symbols
  85. self._count[sym] = count
  86. self._symbols.append(sym)
  87. self._num_symbols += 1
  88. def __len__(self):
  89. return self._num_symbols
  90. def index(self, sym: str):
  91. return self._idx[sym]
  92. def string(self, idx: int):
  93. return self._symbols[idx]
  94. def map_token(self, token: Union[int, str]):
  95. if isinstance(token, int):
  96. token = str(token)
  97. return self.index(token)
  98. def map_tokens(self, tokens):
  99. return [self.map_token(token) for token in tokens]
  100. def decode_tokens(self, tokens):
  101. decoded = [
  102. '50256' if token == 50256 else self.string(token)
  103. for token in tokens
  104. ]
  105. return [int(x) for x in decoded if not x.startswith('vocab_pad_token')]
  106. class CodeGeeXTokenizer(object):
  107. def __init__(
  108. self,
  109. tokenizer: GPT2TokenizerFast = None,
  110. tokenizer_path: str = 'EleutherAI/gpt-j-6B',
  111. start_extra_id: int = 10,
  112. max_len: int = 10,
  113. mode='codegeex-13b',
  114. dict_file: str = None,
  115. ):
  116. self.tokenizer = tokenizer if tokenizer is not None else AutoTokenizer.from_pretrained(
  117. tokenizer_path)
  118. if mode not in ['codegeex-13b', 'codegeex-python-13b']:
  119. raise ValueError(
  120. f"Invalid mode {mode}, choose from ['codegeex-13b', 'codegeex-python-13b']"
  121. )
  122. self.start_extra_id = start_extra_id
  123. self.max_len = max_len
  124. self.mode = mode
  125. if dict_file is not None:
  126. self.code_dict = Code13BDictionary(
  127. dict_file, pad_to_vocab_size=51200
  128. ) if self.mode == 'codegeex-python-13b' else None
  129. else:
  130. self.code_dict = None
  131. self.eos_token_id = self.tokenizer.eos_token_id
  132. def encode_code(self, code: str):
  133. if self.mode == 'codegeex-13b':
  134. code = encode_whitespaces(code, self.start_extra_id, self.max_len)
  135. input_ids = self.tokenizer(
  136. code, is_split_into_words=False).input_ids
  137. elif self.mode == 'codegeex-python-13b':
  138. code = encode_whitespaces(code, self.start_extra_id, self.max_len)
  139. input_ids = self.code_dict.map_tokens(self.tokenizer.encode(code))
  140. input_ids = torch.LongTensor(input_ids).reshape(1, -1)
  141. return input_ids
  142. def decode_code(self, input_ids):
  143. if self.mode == 'codegeex-13b':
  144. text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
  145. output_code = decode_whitespaces(text, self.start_extra_id,
  146. self.max_len)
  147. elif self.mode == 'codegeex-python-13b':
  148. input_ids = [self.code_dict.decode_tokens(input_ids.tolist()[0])]
  149. text = self.tokenizer.decode(input_ids, skip_special_tokens=False)
  150. output_code = decode_whitespaces(text, self.start_extra_id,
  151. self.max_len)
  152. return output_code