tokenization_canine.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # coding=utf-8
  2. # Copyright Google AI and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Tokenization classes for CANINE."""
  16. from typing import Optional
  17. from ...tokenization_utils import AddedToken, PreTrainedTokenizer
  18. from ...utils import logging
# Module-level logger obtained from the library's own `logging` utilities (imported from `...utils`).
# NOTE(review): `logger` is not referenced by any code visible in this file.
logger = logging.get_logger(__name__)
  20. # Unicode defines 1,114,112 total “codepoints”
  21. UNICODE_VOCAB_SIZE = 1114112
  22. # Below: Constants defining canonical codepoints for special, pseudo-characters.
  23. # Copied from https://github.com/google-research/language/blob/master/language/canine/special_codepoints.py
  24. PAD = 0
  25. CLS = 0xE000
  26. SEP = 0xE001
  27. BOS = 0xE002
  28. MASK = 0xE003
  29. RESERVED = 0xE004
  30. # Maps special codepoints to human-readable names.
  31. SPECIAL_CODEPOINTS: dict[int, str] = {
  32. # Special symbols are represented using codepoints values that are valid,
  33. # but designated as "Private Use", meaning that they will never be assigned
  34. # characters by the Unicode Consortium, and are thus safe for use here.
  35. #
  36. # NOTE: Do *NOT* add any sort of [UNK_CHAR] here. They are explicitly
  37. # excluded and should fail with a hard error.
  38. CLS: "[CLS]",
  39. SEP: "[SEP]",
  40. BOS: "[BOS]",
  41. MASK: "[MASK]",
  42. PAD: "[PAD]",
  43. RESERVED: "[RESERVED]",
  44. }
  45. # Maps special codepoint human-readable names to their codepoint values.
  46. SPECIAL_CODEPOINTS_BY_NAME: dict[str, int] = {name: codepoint for codepoint, name in SPECIAL_CODEPOINTS.items()}
  47. class CanineTokenizer(PreTrainedTokenizer):
  48. r"""
  49. Construct a CANINE tokenizer (i.e. a character splitter). It turns text into a sequence of characters, and then
  50. converts each character into its Unicode code point.
  51. [`CanineTokenizer`] inherits from [`PreTrainedTokenizer`].
  52. Refer to superclass [`PreTrainedTokenizer`] for usage examples and documentation concerning parameters.
  53. Args:
  54. model_max_length (`int`, *optional*, defaults to 2048):
  55. The maximum sentence length the model accepts.
  56. """
  57. def __init__(
  58. self,
  59. bos_token=chr(CLS),
  60. eos_token=chr(SEP),
  61. sep_token=chr(SEP),
  62. cls_token=chr(CLS),
  63. pad_token=chr(PAD),
  64. mask_token=chr(MASK),
  65. add_prefix_space=False,
  66. model_max_length=2048,
  67. **kwargs,
  68. ):
  69. bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
  70. eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
  71. sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
  72. cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
  73. pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
  74. # Mask token behave like a normal word, i.e. include the space before it
  75. mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
  76. # Creates a mapping for looking up the IDs of special symbols.
  77. self._special_codepoints: dict[str, int] = {}
  78. for codepoint, name in SPECIAL_CODEPOINTS.items():
  79. self._special_codepoints[name] = codepoint
  80. # Creates a mapping for looking up the string forms of special symbol IDs.
  81. self._special_codepoint_strings: dict[int, str] = {
  82. codepoint: name for name, codepoint in self._special_codepoints.items()
  83. }
  84. self._unicode_vocab_size = UNICODE_VOCAB_SIZE
  85. self._num_special_tokens = len(self._special_codepoints)
  86. super().__init__(
  87. bos_token=bos_token,
  88. eos_token=eos_token,
  89. sep_token=sep_token,
  90. cls_token=cls_token,
  91. pad_token=pad_token,
  92. mask_token=mask_token,
  93. add_prefix_space=add_prefix_space,
  94. model_max_length=model_max_length,
  95. **kwargs,
  96. )
  97. @property
  98. def vocab_size(self) -> int:
  99. return self._unicode_vocab_size
  100. def get_vocab(self):
  101. vocab = {chr(i): i for i in range(self.vocab_size)}
  102. vocab.update(self.added_tokens_encoder)
  103. return vocab
  104. def _tokenize(self, text: str) -> list[str]:
  105. """Tokenize a string (i.e. perform character splitting)."""
  106. return list(text)
  107. def _convert_token_to_id(self, token: str) -> int:
  108. """Converts a token (i.e. a Unicode character) in an id (i.e. its integer Unicode code point value)."""
  109. try:
  110. return ord(token)
  111. except TypeError:
  112. raise ValueError(f"invalid token: '{token}'")
  113. def _convert_id_to_token(self, index: int) -> str:
  114. """
  115. Converts a Unicode code point (integer) in a token (str). In case it's a special code point, convert to
  116. human-readable format.
  117. """
  118. try:
  119. if index in SPECIAL_CODEPOINTS:
  120. return SPECIAL_CODEPOINTS[index]
  121. return chr(index)
  122. except TypeError:
  123. raise ValueError(f"invalid id: {index}")
  124. def convert_tokens_to_string(self, tokens):
  125. return "".join(tokens)
  126. def build_inputs_with_special_tokens(
  127. self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None
  128. ) -> list[int]:
  129. """
  130. Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
  131. adding special tokens. A CANINE sequence has the following format:
  132. - single sequence: `[CLS] X [SEP]`
  133. - pair of sequences: `[CLS] A [SEP] B [SEP]`
  134. Args:
  135. token_ids_0 (`List[int]`):
  136. List of IDs to which the special tokens will be added.
  137. token_ids_1 (`List[int]`, *optional*):
  138. Optional second list of IDs for sequence pairs.
  139. Returns:
  140. `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
  141. """
  142. sep = [self.sep_token_id]
  143. cls = [self.cls_token_id]
  144. result = cls + token_ids_0 + sep
  145. if token_ids_1 is not None:
  146. result += token_ids_1 + sep
  147. return result
  148. def get_special_tokens_mask(
  149. self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
  150. ) -> list[int]:
  151. """
  152. Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
  153. special tokens using the tokenizer `prepare_for_model` method.
  154. Args:
  155. token_ids_0 (`List[int]`):
  156. List of IDs.
  157. token_ids_1 (`List[int]`, *optional*):
  158. Optional second list of IDs for sequence pairs.
  159. already_has_special_tokens (`bool`, *optional*, defaults to `False`):
  160. Whether or not the token list is already formatted with special tokens for the model.
  161. Returns:
  162. `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
  163. """
  164. if already_has_special_tokens:
  165. return super().get_special_tokens_mask(
  166. token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
  167. )
  168. result = [1] + ([0] * len(token_ids_0)) + [1]
  169. if token_ids_1 is not None:
  170. result += ([0] * len(token_ids_1)) + [1]
  171. return result
  172. # CanineTokenizer has no vocab file
  173. def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
  174. return ()
  175. __all__ = ["CanineTokenizer"]