tokenization_speecht5.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # coding=utf-8
  2. # Copyright 2023 The Facebook Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Tokenization class for SpeechT5."""
  16. import os
  17. from shutil import copyfile
  18. from typing import Any, Optional
  19. import sentencepiece as spm
  20. from ...tokenization_utils import PreTrainedTokenizer
  21. from ...utils import logging
  22. from ...utils.import_utils import requires
  23. from .number_normalizer import EnglishNumberNormalizer
  # Module-level logger, named after this module.
  24. logger = logging.get_logger(__name__)
  # Maps the `vocab_file` constructor argument to the file name that `save_vocabulary` writes.
  25. VOCAB_FILES_NAMES = {"vocab_file": "spm_char.model"}
  @requires(backends=("sentencepiece",))
  class SpeechT5Tokenizer(PreTrainedTokenizer):
      """
      Construct a SpeechT5 tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

      This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
      this superclass for more information regarding those methods.

      Args:
          vocab_file (`str`):
              [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
              contains the vocabulary necessary to instantiate a tokenizer.
          bos_token (`str`, *optional*, defaults to `"<s>"`):
              The beginning of sequence token.
          eos_token (`str`, *optional*, defaults to `"</s>"`):
              The end of sequence token.
          unk_token (`str`, *optional*, defaults to `"<unk>"`):
              The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be
              this token instead.
          pad_token (`str`, *optional*, defaults to `"<pad>"`):
              The token used for padding, for example when batching sequences of different lengths.
          normalize (`bool`, *optional*, defaults to `False`):
              Whether to convert numeric quantities in the text to their spelt-out English counterparts.
          sp_model_kwargs (`dict`, *optional*):
              Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
              SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other
              things, to set:

              - `enable_sampling`: Enable subword regularization.
              - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

                - `nbest_size = {0,1}`: No sampling is performed.
                - `nbest_size > 1`: samples from the nbest_size results.
                - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis
                  (lattice) using forward-filtering-and-backward-sampling algorithm.

              - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
                BPE-dropout.

      Attributes:
          sp_model (`SentencePieceProcessor`):
              The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
      """

      vocab_files_names = VOCAB_FILES_NAMES
      model_input_names = ["input_ids", "attention_mask"]

      def __init__(
          self,
          vocab_file,
          bos_token="<s>",
          eos_token="</s>",
          unk_token="<unk>",
          pad_token="<pad>",
          normalize=False,
          sp_model_kwargs: Optional[dict[str, Any]] = None,
          **kwargs,
      ) -> None:
          self.vocab_file = vocab_file
          self.normalize = normalize
          # The number normalizer is only built on first access of the `normalizer` property.
          self._normalizer = None
          self.sp_model_kwargs = sp_model_kwargs if sp_model_kwargs is not None else {}

          # The SentencePiece model is loaded before the parent constructor is called.
          self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
          self.sp_model.Load(vocab_file)

          super().__init__(
              bos_token=bos_token,
              eos_token=eos_token,
              unk_token=unk_token,
              pad_token=pad_token,
              normalize=normalize,
              sp_model_kwargs=self.sp_model_kwargs,
              **kwargs,
          )

      def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
          """
          Pre-process `text` before it is tokenized.

          A `normalize` entry in `kwargs` overrides the instance-level `normalize` setting and is removed from
          `kwargs`. When `is_split_into_words` is true a single space is prepended to `text`. When normalization is
          enabled, `text` is passed through `self.normalizer`.

          Returns:
              `tuple`: The processed text and the remaining keyword arguments.
          """
          apply_normalizer = kwargs.pop("normalize", self.normalize)
          if is_split_into_words:
              text = " " + text
          if apply_normalizer:
              text = self.normalizer(text)
          return text, kwargs

      @property
      def vocab_size(self):
          """Number of pieces reported by the SentencePiece model (added tokens are not counted)."""
          return self.sp_model.get_piece_size()

      @property
      def normalizer(self):
          """The text normalizer; an `EnglishNumberNormalizer` is created on first access if none was set."""
          if self._normalizer is None:
              self._normalizer = EnglishNumberNormalizer()
          return self._normalizer

      @normalizer.setter
      def normalizer(self, value):
          self._normalizer = value

      def get_vocab(self):
          """Return a token-to-id mapping covering the SentencePiece pieces and the added tokens."""
          token_to_id = {}
          for index in range(self.vocab_size):
              token_to_id[self.convert_ids_to_tokens(index)] = index
          # Entries from the added-tokens encoder take precedence over base pieces with the same text.
          token_to_id.update(self.added_tokens_encoder)
          return token_to_id

      def __getstate__(self):
          """Return a copy of the instance state with `sp_model` set to `None`; `__setstate__` rebuilds it."""
          return {**self.__dict__, "sp_model": None}

      def __setstate__(self, d):
          """Restore the instance state and reload the SentencePiece model from `self.vocab_file`."""
          self.__dict__ = d

          # for backward compatibility
          if not hasattr(self, "sp_model_kwargs"):
              self.sp_model_kwargs = {}

          self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
          self.sp_model.Load(self.vocab_file)

      def _tokenize(self, text: str) -> list[str]:
          """Split `text` into a list of SentencePiece pieces (strings)."""
          return self.sp_model.encode(text, out_type=str)

      def _convert_token_to_id(self, token):
          """Look up the id of a token (str) in the SentencePiece vocabulary."""
          return self.sp_model.piece_to_id(token)

      def _convert_id_to_token(self, index):
          """Look up the token (str) for an index (integer) in the SentencePiece vocabulary."""
          return self.sp_model.IdToPiece(index)

      # Copied from transformers.models.albert.tokenization_albert.AlbertTokenizer.convert_tokens_to_string
      def convert_tokens_to_string(self, tokens):
          """Converts a sequence of tokens (string) in a single string."""
          current_sub_tokens = []
          out_string = ""
          prev_is_special = False
          for token in tokens:
              # make sure that special tokens are not decoded using sentencepiece model
              if token in self.all_special_tokens:
                  if not prev_is_special:
                      out_string += " "
                  out_string += self.sp_model.decode(current_sub_tokens) + token
                  prev_is_special = True
                  current_sub_tokens = []
              else:
                  current_sub_tokens.append(token)
                  prev_is_special = False
          out_string += self.sp_model.decode(current_sub_tokens)
          return out_string.strip()

      def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> list[int]:
          """
          Build model inputs by appending `eos_token_id` to the sequence.

          When `token_ids_1` is given, both sequences are concatenated and followed by a single `eos_token_id`.
          """
          eos = [self.eos_token_id]
          # We don't expect to process pairs, but leave the pair logic for API consistency
          if token_ids_1 is not None:
              return token_ids_0 + token_ids_1 + eos
          return token_ids_0 + eos

      def get_special_tokens_mask(
          self, token_ids_0: list[int], token_ids_1: Optional[list[int]] = None, already_has_special_tokens: bool = False
      ) -> list[int]:
          """
          Return a mask with `1` at special-token positions and `0` at sequence-token positions.

          When `already_has_special_tokens` is true the computation is delegated to the superclass. Otherwise the mask
          has one `0` per id in `token_ids_0` (and in `token_ids_1`, if given) followed by a single `1`, matching the
          one trailing token appended by `build_inputs_with_special_tokens`.
          """
          if already_has_special_tokens:
              return super().get_special_tokens_mask(
                  token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
              )

          num_sequence_tokens = len(token_ids_0)
          if token_ids_1 is not None:
              num_sequence_tokens += len(token_ids_1)
          return [0] * num_sequence_tokens + [1]

      def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
          """
          Write the SentencePiece model file into `save_directory`.

          The file is named `spm_char.model`, preceded by `filename_prefix` and a dash when a prefix is given. If
          `self.vocab_file` exists on disk it is copied (unless it already is the target path); otherwise the
          serialized model held by `self.sp_model` is written out.

          Returns:
              `tuple[str]`: A one-element tuple with the path of the vocabulary file. If `save_directory` is not a
              directory, an error is logged and `None` is returned instead.
          """
          if not os.path.isdir(save_directory):
              logger.error(f"Vocabulary path ({save_directory}) should be a directory")
              return

          name_prefix = filename_prefix + "-" if filename_prefix else ""
          out_vocab_file = os.path.join(save_directory, name_prefix + VOCAB_FILES_NAMES["vocab_file"])

          if not os.path.isfile(self.vocab_file):
              # The original model file is gone: dump the in-memory model instead.
              with open(out_vocab_file, "wb") as model_file:
                  model_file.write(self.sp_model.serialized_model_proto())
          elif os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
              copyfile(self.vocab_file, out_vocab_file)

          return (out_vocab_file,)
  # Public names exported by this module.
  184. __all__ = ["SpeechT5Tokenizer"]