tokenization_rag.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # coding=utf-8
  2. # Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Tokenization classes for RAG."""
  16. import os
  17. import warnings
  18. from typing import Optional
  19. from ...tokenization_utils_base import BatchEncoding
  20. from ...utils import logging
  21. from .configuration_rag import RagConfig
# Module-level logger, obtained from the library's own logging utilities (`...utils.logging`).
logger = logging.get_logger(__name__)
  23. class RagTokenizer:
  24. def __init__(self, question_encoder, generator):
  25. self.question_encoder = question_encoder
  26. self.generator = generator
  27. self.current_tokenizer = self.question_encoder
  28. def save_pretrained(self, save_directory):
  29. if os.path.isfile(save_directory):
  30. raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
  31. os.makedirs(save_directory, exist_ok=True)
  32. question_encoder_path = os.path.join(save_directory, "question_encoder_tokenizer")
  33. generator_path = os.path.join(save_directory, "generator_tokenizer")
  34. self.question_encoder.save_pretrained(question_encoder_path)
  35. self.generator.save_pretrained(generator_path)
  36. @classmethod
  37. def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
  38. # dynamically import AutoTokenizer
  39. from ..auto.tokenization_auto import AutoTokenizer
  40. config = kwargs.pop("config", None)
  41. if config is None:
  42. config = RagConfig.from_pretrained(pretrained_model_name_or_path)
  43. question_encoder = AutoTokenizer.from_pretrained(
  44. pretrained_model_name_or_path, config=config.question_encoder, subfolder="question_encoder_tokenizer"
  45. )
  46. generator = AutoTokenizer.from_pretrained(
  47. pretrained_model_name_or_path, config=config.generator, subfolder="generator_tokenizer"
  48. )
  49. return cls(question_encoder=question_encoder, generator=generator)
  50. def __call__(self, *args, **kwargs):
  51. return self.current_tokenizer(*args, **kwargs)
  52. def batch_decode(self, *args, **kwargs):
  53. return self.generator.batch_decode(*args, **kwargs)
  54. def decode(self, *args, **kwargs):
  55. return self.generator.decode(*args, **kwargs)
  56. def _switch_to_input_mode(self):
  57. self.current_tokenizer = self.question_encoder
  58. def _switch_to_target_mode(self):
  59. self.current_tokenizer = self.generator
  60. def prepare_seq2seq_batch(
  61. self,
  62. src_texts: list[str],
  63. tgt_texts: Optional[list[str]] = None,
  64. max_length: Optional[int] = None,
  65. max_target_length: Optional[int] = None,
  66. padding: str = "longest",
  67. return_tensors: Optional[str] = None,
  68. truncation: bool = True,
  69. **kwargs,
  70. ) -> BatchEncoding:
  71. warnings.warn(
  72. "`prepare_seq2seq_batch` is deprecated and will be removed in version 5 of 🤗 Transformers. Use the "
  73. "regular `__call__` method to prepare your inputs and the tokenizer under the `with_target_tokenizer` "
  74. "context manager to prepare your targets. See the documentation of your specific tokenizer for more "
  75. "details",
  76. FutureWarning,
  77. )
  78. if max_length is None:
  79. max_length = self.current_tokenizer.model_max_length
  80. model_inputs = self(
  81. src_texts,
  82. add_special_tokens=True,
  83. return_tensors=return_tensors,
  84. max_length=max_length,
  85. padding=padding,
  86. truncation=truncation,
  87. **kwargs,
  88. )
  89. if tgt_texts is None:
  90. return model_inputs
  91. # Process tgt_texts
  92. if max_target_length is None:
  93. max_target_length = self.current_tokenizer.model_max_length
  94. labels = self(
  95. text_target=tgt_texts,
  96. add_special_tokens=True,
  97. return_tensors=return_tensors,
  98. padding=padding,
  99. max_length=max_target_length,
  100. truncation=truncation,
  101. **kwargs,
  102. )
  103. model_inputs["labels"] = labels["input_ids"]
  104. return model_inputs
# Public API of this module: the names exported by `from ... import *`.
__all__ = ["RagTokenizer"]