tokenization.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. import os
  2. from typing import Dict, List, Optional, Union
  3. from sentencepiece import SentencePieceProcessor
  4. from transformers import PreTrainedTokenizer
  5. from transformers.tokenization_utils_base import BatchEncoding, EncodedInput
  6. from transformers.utils import PaddingStrategy
  7. class SPTokenizer:
  8. def __init__(self, model_path: str):
  9. # reload tokenizer
  10. assert os.path.isfile(model_path), model_path
  11. self.sp_model = SentencePieceProcessor(model_file=model_path)
  12. # BOS / EOS token IDs
  13. self.n_words: int = self.sp_model.vocab_size()
  14. self.bos_id: int = self.sp_model.bos_id()
  15. self.eos_id: int = self.sp_model.eos_id()
  16. self.pad_id: int = self.sp_model.unk_id()
  17. assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
  18. special_tokens = ['[MASK]', '[gMASK]', '[sMASK]', 'sop', 'eop']
  19. self.special_tokens = {}
  20. self.index_special_tokens = {}
  21. for token in special_tokens:
  22. self.special_tokens[token] = self.n_words
  23. self.index_special_tokens[self.n_words] = token
  24. self.n_words += 1
  25. def tokenize(self, s: str):
  26. return self.sp_model.EncodeAsPieces(s)
  27. def encode(self,
  28. s: str,
  29. bos: bool = False,
  30. eos: bool = False) -> List[int]:
  31. assert type(s) is str
  32. t = self.sp_model.encode(s)
  33. if bos:
  34. t = [self.bos_id] + t
  35. if eos:
  36. t = t + [self.eos_id]
  37. return t
  38. def decode(self, t: List[int]) -> str:
  39. return self.sp_model.decode(t)
  40. def decode_tokens(self, tokens: List[str]) -> str:
  41. text = self.sp_model.DecodePieces(tokens)
  42. return text
  43. def convert_token_to_id(self, token):
  44. """ Converts a token (str) in an id using the vocab. """
  45. if token in self.special_tokens:
  46. return self.special_tokens[token]
  47. return self.sp_model.PieceToId(token)
  48. def convert_id_to_token(self, index):
  49. """Converts an index (integer) in a token (str) using the vocab."""
  50. if index in self.index_special_tokens or index in [
  51. self.eos_id, self.bos_id, self.pad_id
  52. ] or index < 0:
  53. return ''
  54. return self.sp_model.IdToPiece(index)
  55. class ChatGLM2Tokenizer(PreTrainedTokenizer):
  56. vocab_files_names = {'vocab_file': 'tokenizer.model'}
  57. model_input_names = ['input_ids', 'attention_mask', 'position_ids']
  58. def __init__(self, vocab_file, padding_side='left', **kwargs):
  59. self.name = 'GLMTokenizer'
  60. self.vocab_file = vocab_file
  61. self.tokenizer = SPTokenizer(vocab_file)
  62. self.special_tokens = {
  63. '<bos>': self.tokenizer.bos_id,
  64. '<eos>': self.tokenizer.eos_id,
  65. '<pad>': self.tokenizer.pad_id
  66. }
  67. super().__init__(padding_side=padding_side, **kwargs)
  68. def get_command(self, token):
  69. if token in self.special_tokens:
  70. return self.special_tokens[token]
  71. assert token in self.tokenizer.special_tokens, f'{token} is not a special token for {self.name}'
  72. return self.tokenizer.special_tokens[token]
  73. @property
  74. def pad_token(self) -> str:
  75. return '<unk>'
  76. @property
  77. def pad_token_id(self):
  78. return self.get_command('<pad>')
  79. @property
  80. def eos_token(self) -> str:
  81. return '</s>'
  82. @property
  83. def eos_token_id(self):
  84. return self.get_command('<eos>')
  85. @property
  86. def vocab_size(self):
  87. return self.tokenizer.n_words
  88. def get_vocab(self):
  89. """ Returns vocab as a dict """
  90. vocab = {
  91. self._convert_id_to_token(i): i
  92. for i in range(self.vocab_size)
  93. }
  94. vocab.update(self.added_tokens_encoder)
  95. return vocab
  96. def _tokenize(self, text, **kwargs):
  97. return self.tokenizer.tokenize(text)
  98. def _convert_token_to_id(self, token):
  99. """ Converts a token (str) in an id using the vocab. """
  100. return self.tokenizer.convert_token_to_id(token)
  101. def _convert_id_to_token(self, index):
  102. """Converts an index (integer) in a token (str) using the vocab."""
  103. return self.tokenizer.convert_id_to_token(index)
  104. def convert_tokens_to_string(self, tokens: List[str]) -> str:
  105. return self.tokenizer.decode_tokens(tokens)
  106. def save_vocabulary(self, save_directory, filename_prefix=None):
  107. """
  108. Save the vocabulary and special tokens file to a directory.
  109. Args:
  110. save_directory (`str`):
  111. The directory in which to save the vocabulary.
  112. filename_prefix (`str`, *optional*):
  113. An optional prefix to add to the named of the saved files.
  114. Returns:
  115. `Tuple(str)`: Paths to the files saved.
  116. """
  117. if os.path.isdir(save_directory):
  118. vocab_file = os.path.join(save_directory,
  119. self.vocab_files_names['vocab_file'])
  120. else:
  121. vocab_file = save_directory
  122. with open(self.vocab_file, 'rb') as fin:
  123. proto_str = fin.read()
  124. with open(vocab_file, 'wb') as writer:
  125. writer.write(proto_str)
  126. return (vocab_file, )
  127. def get_prefix_tokens(self):
  128. prefix_tokens = [self.get_command('[gMASK]'), self.get_command('sop')]
  129. return prefix_tokens
  130. def build_prompt(self, query, history=None):
  131. if history is None:
  132. history = []
  133. prompt = ''
  134. for i, (old_query, response) in enumerate(history):
  135. prompt += '[Round {}]\n\n问:{}\n\n答:{}\n\n'.format(
  136. i + 1, old_query, response)
  137. prompt += '[Round {}]\n\n问:{}\n\n答:'.format(len(history) + 1, query)
  138. return prompt
  139. def build_inputs_with_special_tokens(
  140. self,
  141. token_ids_0: List[int],
  142. token_ids_1: Optional[List[int]] = None) -> List[int]:
  143. """
  144. Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
  145. adding special tokens. A BERT sequence has the following format:
  146. - single sequence: `[CLS] X [SEP]`
  147. - pair of sequences: `[CLS] A [SEP] B [SEP]`
  148. Args:
  149. token_ids_0 (`List[int]`):
  150. List of IDs to which the special tokens will be added.
  151. token_ids_1 (`List[int]`, *optional*):
  152. Optional second list of IDs for sequence pairs.
  153. Returns:
  154. `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
  155. """
  156. prefix_tokens = self.get_prefix_tokens()
  157. token_ids_0 = prefix_tokens + token_ids_0
  158. if token_ids_1 is not None:
  159. token_ids_0 = token_ids_0 + token_ids_1 + [
  160. self.get_command('<eos>')
  161. ]
  162. return token_ids_0
  163. def _pad(
  164. self,
  165. encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
  166. max_length: Optional[int] = None,
  167. padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
  168. pad_to_multiple_of: Optional[int] = None,
  169. return_attention_mask: Optional[bool] = None,
  170. ) -> dict:
  171. """
  172. Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
  173. Args:
  174. encoded_inputs:
  175. Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
  176. max_length: maximum length of the returned list and optionally padding length (see below).
  177. Will truncate by taking into account the special tokens.
  178. padding_strategy: PaddingStrategy to use for padding.
  179. - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
  180. - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
  181. - PaddingStrategy.DO_NOT_PAD: Do not pad
  182. The tokenizer padding sides are defined in self.padding_side:
  183. - 'left': pads on the left of the sequences
  184. - 'right': pads on the right of the sequences
  185. pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
  186. This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
  187. `>= 7.5` (Volta).
  188. return_attention_mask:
  189. (optional) Set to False to avoid returning attention mask (default: set to model specifics)
  190. """
  191. # Load from model defaults
  192. assert self.padding_side == 'left'
  193. required_input = encoded_inputs[self.model_input_names[0]]
  194. seq_length = len(required_input)
  195. if padding_strategy == PaddingStrategy.LONGEST:
  196. max_length = len(required_input)
  197. if max_length is not None and pad_to_multiple_of is not None and (
  198. max_length % pad_to_multiple_of != 0):
  199. max_length = (
  200. (max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
  201. needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(
  202. required_input) != max_length
  203. # Initialize attention mask if not present.
  204. if 'attention_mask' not in encoded_inputs:
  205. encoded_inputs['attention_mask'] = [1] * seq_length
  206. if 'position_ids' not in encoded_inputs:
  207. encoded_inputs['position_ids'] = list(range(seq_length))
  208. if needs_to_be_padded:
  209. difference = max_length - len(required_input)
  210. if 'attention_mask' in encoded_inputs:
  211. encoded_inputs['attention_mask'] = [
  212. 0
  213. ] * difference + encoded_inputs['attention_mask']
  214. if 'position_ids' in encoded_inputs:
  215. encoded_inputs['position_ids'] = [
  216. 0
  217. ] * difference + encoded_inputs['position_ids']
  218. encoded_inputs[self.model_input_names[
  219. 0]] = [self.pad_token_id] * difference + required_input
  220. return encoded_inputs