tokenization_myt5.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. # coding=utf-8
  2. # Copyright 2024
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Tokenization class for model MyT5."""
  16. import json
  17. import os
  18. import warnings
  19. from collections import defaultdict
  20. from typing import Optional, Union
  21. from ...tokenization_utils import AddedToken, PreTrainedTokenizer
  22. from ...utils import logging
  23. logger = logging.get_logger(__name__)
  24. VOCAB_FILES_NAMES = {"vocab_file": "byte_maps.json"}
  25. class ByteRewriter:
  26. """
  27. Byte rewriter class for MyT5 tokenizer.
  28. This class is used to rewrite bytes using a hash tree. The hash tree is constructed from a set of rewriting rules.
  29. Args:
  30. rewriting_rules (`str` or `dict[str, str]`):
  31. A path to a json file containing the rewriting rules or a dictionary containing the rewriting rules.
  32. """
  33. LEAF = "[LEAF]"
  34. def __init__(self, rewriting_rules: Union[str, dict[str, str]]):
  35. if isinstance(rewriting_rules, str):
  36. with open(rewriting_rules, "r") as f:
  37. rewriting_rules = json.load(f)
  38. elif not isinstance(rewriting_rules, dict):
  39. raise TypeError(
  40. f"rewriting_rules should be either a path to json file or a dict, got {type(rewriting_rules)}"
  41. )
  42. self.hash_tree = self.construct_hash_tree(rewriting_rules)
  43. reverse_rewriting_rules = {v: k for k, v in rewriting_rules.items()}
  44. self.reverse_hash_tree = self.construct_hash_tree(reverse_rewriting_rules)
  45. def add_leaf(self, hash_tree: dict[str, Union[dict, list[str]]], byte_in_sequence: str, byte_out_sequence: str):
  46. """
  47. Add a leaf with the output byte sequence to the hash tree.
  48. """
  49. byte_in_list = byte_in_sequence.split(" ")
  50. byte_out_list = byte_out_sequence.split(" ")
  51. tree_pointer = hash_tree
  52. for b in byte_in_list:
  53. if b not in tree_pointer:
  54. tree_pointer[b] = {}
  55. tree_pointer = tree_pointer[b]
  56. tree_pointer[self.LEAF] = byte_out_list
  57. def construct_hash_tree(self, rewriting_rules: dict[str, str]) -> dict[str, Union[dict, list[str]]]:
  58. """
  59. Construct a hash tree for rewritten byte sequences.
  60. """
  61. hash_tree = defaultdict(dict)
  62. for b in (f"{x:02x}" for x in range(256)):
  63. hash_tree[b][self.LEAF] = [b]
  64. for in_sequence, out_sequence in rewriting_rules.items():
  65. self.add_leaf(hash_tree, in_sequence, out_sequence)
  66. return hash_tree
  67. def search_hash_tree(self, byte_sequence: list[str]) -> Union[None, list[str]]:
  68. """
  69. Search the hash tree and return the rewritten byte sequence if found.
  70. """
  71. tree_pointer = self.hash_tree
  72. for b in byte_sequence:
  73. if b in tree_pointer:
  74. tree_pointer = tree_pointer[b]
  75. else:
  76. return None
  77. return tree_pointer[self.LEAF]
  78. def rewrite_bytes(self, in_bytes: list[str], reverse=False) -> list[str]:
  79. """
  80. Rewrite a sequence of bytes using the hash tree.
  81. Args:
  82. in_bytes (`list[str]`): A list of bytes to be rewritten.
  83. reverse (`bool`): If True, decoding is performed with the reverse hash tree.
  84. Returns:
  85. `list[str]`: The rewritten byte sequence.
  86. """
  87. out_bytes = []
  88. b_start = 0
  89. b_end = 0
  90. while b_start < len(in_bytes):
  91. tree_pointer = self.hash_tree if not reverse else self.reverse_hash_tree
  92. for j in range(b_start, len(in_bytes)):
  93. b = in_bytes[j]
  94. if b in tree_pointer:
  95. tree_pointer = tree_pointer[b]
  96. elif j == b_start:
  97. cur_leaf = [b]
  98. b_end = j
  99. break
  100. else:
  101. break
  102. if self.LEAF in tree_pointer:
  103. cur_leaf = tree_pointer[self.LEAF]
  104. b_end = j
  105. out_bytes.extend(cur_leaf)
  106. b_start = b_end + 1
  107. return out_bytes
  108. class MyT5Tokenizer(PreTrainedTokenizer):
  109. """
  110. Construct a MyT5 tokenizer.
  111. This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
  112. this superclass for more information regarding those methods.
  113. Args:
  114. vocab_file (`str`): The file containing the byte rewriting rules.
  115. eos_token (`str`, *optional*, defaults to `"</s>"`):
  116. The end of sequence token.
  117. unk_token (`str`, *optional*, defaults to `"<unk>"`):
  118. The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  119. token instead.
  120. pad_token (`str`, *optional*, defaults to `"<pad>"`):
  121. The token used for padding, for example when batching sequences of different lengths.
  122. extra_ids (`int`, *optional*, defaults to 125):
  123. Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
  124. accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
  125. indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
  126. like in ByT5 preprocessing see
  127. [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
  128. additional_special_tokens (`list[str]`, *optional*):
  129. Additional special tokens used by the tokenizer.
  130. """
  131. model_input_names = ["input_ids", "attention_mask"]
  132. vocab_files_names = VOCAB_FILES_NAMES
  133. def __init__(
  134. self,
  135. vocab_file,
  136. eos_token="</s>",
  137. unk_token="<unk>",
  138. pad_token="<pad>",
  139. extra_ids=125,
  140. additional_special_tokens=None,
  141. **kwargs,
  142. ) -> None:
  143. # Add extra_ids to the special token list
  144. if extra_ids > 0 and additional_special_tokens is None:
  145. additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
  146. elif extra_ids > 0 and additional_special_tokens is not None and len(additional_special_tokens) > 0:
  147. # Check that we have the right number of extra_id special tokens
  148. extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
  149. if extra_tokens != extra_ids:
  150. raise ValueError(
  151. f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are"
  152. " provided to MyT5Tokenizer. In this case the additional_special_tokens must include the"
  153. " extra_ids tokens"
  154. )
  155. pad_token = AddedToken(pad_token, lstrip=True, rstrip=True) if isinstance(pad_token, str) else pad_token
  156. eos_token = AddedToken(eos_token, lstrip=True, rstrip=True) if isinstance(eos_token, str) else eos_token
  157. unk_token = AddedToken(unk_token, lstrip=True, rstrip=True) if isinstance(unk_token, str) else unk_token
  158. # unk token needs to be in the vocab with correct index
  159. self._added_tokens_decoder = {0: pad_token, 1: eos_token, 2: unk_token}
  160. self.offset = len(self._added_tokens_decoder)
  161. self._utf_vocab_size = 2**8 # utf is 8 bits
  162. # Load byte maps
  163. self.byte_maps = json.load(open(vocab_file, "r"))
  164. self.decompose_rewriter = ByteRewriter(self.byte_maps["decompose_map"])
  165. self.merge_rewriter = ByteRewriter(self.byte_maps["merge_map"])
  166. super().__init__(
  167. eos_token=eos_token,
  168. unk_token=unk_token,
  169. pad_token=pad_token,
  170. extra_ids=0,
  171. additional_special_tokens=additional_special_tokens,
  172. **kwargs,
  173. )
  174. @property
  175. def vocab_size(self):
  176. return self._utf_vocab_size
  177. # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_vocab
  178. def get_vocab(self):
  179. vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size + self.offset)}
  180. vocab.update(self.added_tokens_encoder)
  181. return vocab
  182. # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.get_special_tokens_mask
  183. def get_special_tokens_mask(
  184. self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
  185. ) -> list[int]:
  186. """
  187. Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
  188. special tokens using the tokenizer `prepare_for_model` method.
  189. Args:
  190. token_ids_0 (`list[int]`):
  191. List of IDs.
  192. token_ids_1 (`list[int]`, *optional*):
  193. Optional second list of IDs for sequence pairs.
  194. already_has_special_tokens (`bool`, *optional*, defaults to `False`):
  195. Whether or not the token list is already formatted with special tokens for the model.
  196. Returns:
  197. `list[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
  198. """
  199. if already_has_special_tokens:
  200. return super().get_special_tokens_mask(
  201. token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
  202. )
  203. # normal case: some special tokens
  204. if token_ids_1 is None:
  205. return ([0] * len(token_ids_0)) + [1]
  206. return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
  207. def _add_eos_if_not_present(self, token_ids: list[int]) -> list[int]:
  208. """Do not add eos again if user already added it."""
  209. if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
  210. warnings.warn(
  211. f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated"
  212. " eos tokens being added."
  213. )
  214. return token_ids
  215. else:
  216. return token_ids + [self.eos_token_id]
  217. def create_token_type_ids_from_sequences(
  218. self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
  219. ) -> list[int]:
  220. """
  221. Create a mask from the two sequences passed to be used in a sequence-pair classification task. MyT5 does not
  222. make use of token type ids, therefore a list of zeros is returned.
  223. Args:
  224. token_ids_0 (`list[int]`):
  225. List of IDs.
  226. token_ids_1 (`list[int]`, *optional*):
  227. Optional second list of IDs for sequence pairs.
  228. Returns:
  229. `list[int]`: List of zeros.
  230. """
  231. eos = [self.eos_token_id]
  232. if token_ids_1 is None:
  233. return len(token_ids_0 + eos) * [0]
  234. return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
  235. # Copied from transformers.models.byt5.tokenization_byt5.ByT5Tokenizer.build_inputs_with_special_tokens
  236. def build_inputs_with_special_tokens(
  237. self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
  238. ) -> list[int]:
  239. """
  240. Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
  241. adding special tokens. A sequence has the following format:
  242. - single sequence: `X </s>`
  243. - pair of sequences: `A </s> B </s>`
  244. Args:
  245. token_ids_0 (`list[int]`):
  246. List of IDs to which the special tokens will be added.
  247. token_ids_1 (`list[int]`, *optional*):
  248. Optional second list of IDs for sequence pairs.
  249. Returns:
  250. `list[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
  251. """
  252. token_ids_0 = self._add_eos_if_not_present(token_ids_0)
  253. if token_ids_1 is None:
  254. return token_ids_0
  255. else:
  256. token_ids_1 = self._add_eos_if_not_present(token_ids_1)
  257. return token_ids_0 + token_ids_1
  258. def _tokenize(self, text: str, **kwargs) -> list[str]:
  259. """Take as input a string and return a list of strings (tokens) for words/sub-words.
  260. Represents tokens in two character hex format"""
  261. tokens = [f"{i:02x}" for i in text.encode("utf-8")]
  262. tokens = self.morphological_encode(tokens)
  263. return tokens
  264. def _convert_token_to_id(self, token):
  265. """Converts a token (str) in an id using the vocab."""
  266. if len(token) != 2:
  267. token_id = None
  268. else:
  269. token_id = int(token, 16) + self.offset
  270. return token_id
  271. def _convert_id_to_token(self, index):
  272. """Converts an index (integer) in a token (str) using the vocab."""
  273. token = f"{index - self.offset:02x}"
  274. return token
  275. def morphological_encode(self, indices: list[str]) -> list[str]:
  276. # Decompose and merge morphological sequences
  277. indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=False)
  278. indices = self.merge_rewriter.rewrite_bytes(indices, reverse=False)
  279. return indices
  280. def morphological_decode(self, indices: list[str]) -> list[str]:
  281. # Demerge and compose morphological sequences
  282. indices = self.merge_rewriter.rewrite_bytes(indices, reverse=True)
  283. indices = self.decompose_rewriter.rewrite_bytes(indices, reverse=True)
  284. return indices
  285. def convert_tokens_to_string(self, tokens):
  286. """Converts a sequence of tokens (string) in a single string."""
  287. bstring = b""
  288. out_tokens = []
  289. for token in tokens:
  290. if token in self.added_tokens_decoder:
  291. out_tokens.append(self.added_tokens_decoder[token])
  292. elif token in self.added_tokens_encoder:
  293. out_tokens.append(token)
  294. else:
  295. out_tokens.append(token)
  296. out_tokens = self.morphological_decode(out_tokens)
  297. _added_tokens = set(self.added_tokens_decoder.values()) | set(self.added_tokens_encoder)
  298. for token in out_tokens:
  299. if token in _added_tokens:
  300. bstring += bytes(token, "utf-8")
  301. else:
  302. bstring += bytes.fromhex(token)
  303. string = bstring.decode("utf-8", errors="ignore")
  304. return string
  305. def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
  306. if os.path.isdir(save_directory):
  307. vocab_file = os.path.join(
  308. save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
  309. )
  310. else:
  311. vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
  312. with open(vocab_file, "w", encoding="utf-8") as writer:
  313. writer.write(json.dumps(self.byte_maps, indent=2, ensure_ascii=False))
  314. return (vocab_file,)
  315. __all__ = ["MyT5Tokenizer"]