# text2text_generation.py (18 KB)

  1. import enum
  2. import warnings
  3. from typing import Any, Union
  4. from ..generation import GenerationConfig
  5. from ..tokenization_utils import TruncationStrategy
  6. from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
  7. from .base import Pipeline, build_pipeline_init_args
  8. if is_tf_available():
  9. import tensorflow as tf
  10. from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
  11. if is_torch_available():
  12. from ..models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
# Module-level logger; the pipelines' `check_inputs` hooks below emit their warnings through it.
logger = logging.get_logger(__name__)
  14. class ReturnType(enum.Enum):
  15. TENSORS = 0
  16. TEXT = 1
  17. @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
  18. class Text2TextGenerationPipeline(Pipeline):
  19. """
  20. Pipeline for text to text generation using seq2seq models.
  21. Unless the model you're using explicitly sets these generation parameters in its configuration files
  22. (`generation_config.json`), the following default values will be used:
  23. - max_new_tokens: 256
  24. - num_beams: 4
  25. Example:
  26. ```python
  27. >>> from transformers import pipeline
  28. >>> generator = pipeline(model="mrm8488/t5-base-finetuned-question-generation-ap")
  29. >>> generator(
  30. ... "answer: Manuel context: Manuel has created RuPERTa-base with the support of HF-Transformers and Google"
  31. ... )
  32. [{'generated_text': 'question: Who created the RuPERTa-base?'}]
  33. ```
  34. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text
  35. generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about
  36. text generation parameters in [Text generation strategies](../generation_strategies) and [Text
  37. generation](text_generation).
  38. This Text2TextGenerationPipeline pipeline can currently be loaded from [`pipeline`] using the following task
  39. identifier: `"text2text-generation"`.
  40. The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
  41. up-to-date list of available models on
  42. [huggingface.co/models](https://huggingface.co/models?filter=text2text-generation). For a list of available
  43. parameters, see the [following
  44. documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)
  45. Usage:
  46. ```python
  47. text2text_generator = pipeline("text2text-generation")
  48. text2text_generator("question: What is 42 ? context: 42 is the answer to life, the universe and everything")
  49. ```"""
  50. _pipeline_calls_generate = True
  51. _load_processor = False
  52. _load_image_processor = False
  53. _load_feature_extractor = False
  54. _load_tokenizer = True
  55. # Make sure the docstring is updated when the default generation config is changed (in all pipelines in this file)
  56. _default_generation_config = GenerationConfig(
  57. max_new_tokens=256,
  58. num_beams=4,
  59. )
  60. # Used in the return key of the pipeline.
  61. return_name = "generated"
  62. def __init__(self, *args, **kwargs):
  63. super().__init__(*args, **kwargs)
  64. self.check_model_type(
  65. TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
  66. if self.framework == "tf"
  67. else MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
  68. )
  69. def _sanitize_parameters(
  70. self,
  71. return_tensors=None,
  72. return_text=None,
  73. return_type=None,
  74. clean_up_tokenization_spaces=None,
  75. truncation=None,
  76. stop_sequence=None,
  77. **generate_kwargs,
  78. ):
  79. preprocess_params = {}
  80. if truncation is not None:
  81. preprocess_params["truncation"] = truncation
  82. forward_params = generate_kwargs
  83. postprocess_params = {}
  84. if return_tensors is not None and return_type is None:
  85. return_type = ReturnType.TENSORS if return_tensors else ReturnType.TEXT
  86. if return_type is not None:
  87. postprocess_params["return_type"] = return_type
  88. if clean_up_tokenization_spaces is not None:
  89. postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
  90. if stop_sequence is not None:
  91. stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
  92. if len(stop_sequence_ids) > 1:
  93. warnings.warn(
  94. "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
  95. " the stop sequence will be used as the stop sequence string in the interim."
  96. )
  97. generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
  98. if self.assistant_model is not None:
  99. forward_params["assistant_model"] = self.assistant_model
  100. if self.assistant_tokenizer is not None:
  101. forward_params["tokenizer"] = self.tokenizer
  102. forward_params["assistant_tokenizer"] = self.assistant_tokenizer
  103. return preprocess_params, forward_params, postprocess_params
  104. def check_inputs(self, input_length: int, min_length: int, max_length: int):
  105. """
  106. Checks whether there might be something wrong with given input with regard to the model.
  107. """
  108. return True
  109. def _parse_and_tokenize(self, *args, truncation):
  110. prefix = self.prefix if self.prefix is not None else ""
  111. if isinstance(args[0], list):
  112. if self.tokenizer.pad_token_id is None:
  113. raise ValueError("Please make sure that the tokenizer has a pad_token_id when using a batch input")
  114. args = ([prefix + arg for arg in args[0]],)
  115. padding = True
  116. elif isinstance(args[0], str):
  117. args = (prefix + args[0],)
  118. padding = False
  119. else:
  120. raise TypeError(
  121. f" `args[0]`: {args[0]} have the wrong format. The should be either of type `str` or type `list`"
  122. )
  123. inputs = self.tokenizer(*args, padding=padding, truncation=truncation, return_tensors=self.framework)
  124. # This is produced by tokenizers but is an invalid generate kwargs
  125. if "token_type_ids" in inputs:
  126. del inputs["token_type_ids"]
  127. return inputs
  128. def __call__(self, *args: Union[str, list[str]], **kwargs: Any) -> list[dict[str, str]]:
  129. r"""
  130. Generate the output text(s) using text(s) given as inputs.
  131. Args:
  132. args (`str` or `list[str]`):
  133. Input text for the encoder.
  134. return_tensors (`bool`, *optional*, defaults to `False`):
  135. Whether or not to include the tensors of predictions (as token indices) in the outputs.
  136. return_text (`bool`, *optional*, defaults to `True`):
  137. Whether or not to include the decoded texts in the outputs.
  138. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
  139. Whether or not to clean up the potential extra spaces in the text output.
  140. truncation (`TruncationStrategy`, *optional*, defaults to `TruncationStrategy.DO_NOT_TRUNCATE`):
  141. The truncation strategy for the tokenization within the pipeline. `TruncationStrategy.DO_NOT_TRUNCATE`
  142. (default) will never truncate, but it is sometimes desirable to truncate the input to fit the model's
  143. max_length instead of throwing an error down the line.
  144. generate_kwargs:
  145. Additional keyword arguments to pass along to the generate method of the model (see the generate method
  146. corresponding to your framework [here](./text_generation)).
  147. Return:
  148. A list or a list of list of `dict`: Each result comes as a dictionary with the following keys:
  149. - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
  150. - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
  151. ids of the generated text.
  152. """
  153. result = super().__call__(*args, **kwargs)
  154. if (
  155. isinstance(args[0], list)
  156. and all(isinstance(el, str) for el in args[0])
  157. and all(len(res) == 1 for res in result)
  158. ):
  159. return [res[0] for res in result]
  160. return result
  161. def preprocess(self, inputs, truncation=TruncationStrategy.DO_NOT_TRUNCATE, **kwargs):
  162. inputs = self._parse_and_tokenize(inputs, truncation=truncation, **kwargs)
  163. return inputs
  164. def _forward(self, model_inputs, **generate_kwargs):
  165. if self.framework == "pt":
  166. in_b, input_length = model_inputs["input_ids"].shape
  167. elif self.framework == "tf":
  168. in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
  169. self.check_inputs(
  170. input_length,
  171. generate_kwargs.get("min_length", self.generation_config.min_length),
  172. generate_kwargs.get("max_length", self.generation_config.max_length),
  173. )
  174. # User-defined `generation_config` passed to the pipeline call take precedence
  175. if "generation_config" not in generate_kwargs:
  176. generate_kwargs["generation_config"] = self.generation_config
  177. output_ids = self.model.generate(**model_inputs, **generate_kwargs)
  178. out_b = output_ids.shape[0]
  179. if self.framework == "pt":
  180. output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
  181. elif self.framework == "tf":
  182. output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
  183. return {"output_ids": output_ids}
  184. def postprocess(self, model_outputs, return_type=ReturnType.TEXT, clean_up_tokenization_spaces=False):
  185. records = []
  186. for output_ids in model_outputs["output_ids"][0]:
  187. if return_type == ReturnType.TENSORS:
  188. record = {f"{self.return_name}_token_ids": output_ids}
  189. elif return_type == ReturnType.TEXT:
  190. record = {
  191. f"{self.return_name}_text": self.tokenizer.decode(
  192. output_ids,
  193. skip_special_tokens=True,
  194. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  195. )
  196. }
  197. records.append(record)
  198. return records
  199. @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
  200. class SummarizationPipeline(Text2TextGenerationPipeline):
  201. """
  202. Summarize news articles and other documents.
  203. This summarizing pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  204. `"summarization"`.
  205. The models that this pipeline can use are models that have been fine-tuned on a summarization task, which is
  206. currently, '*bart-large-cnn*', '*google-t5/t5-small*', '*google-t5/t5-base*', '*google-t5/t5-large*', '*google-t5/t5-3b*', '*google-t5/t5-11b*'. See the up-to-date
  207. list of available models on [huggingface.co/models](https://huggingface.co/models?filter=summarization). For a list
  208. of available parameters, see the [following
  209. documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)
  210. Unless the model you're using explicitly sets these generation parameters in its configuration files
  211. (`generation_config.json`), the following default values will be used:
  212. - max_new_tokens: 256
  213. - num_beams: 4
  214. Usage:
  215. ```python
  216. # use bart in pytorch
  217. summarizer = pipeline("summarization")
  218. summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)
  219. # use t5 in tf
  220. summarizer = pipeline("summarization", model="google-t5/t5-base", tokenizer="google-t5/t5-base", framework="tf")
  221. summarizer("An apple a day, keeps the doctor away", min_length=5, max_length=20)
  222. ```"""
  223. # Used in the return key of the pipeline.
  224. return_name = "summary"
  225. def __call__(self, *args, **kwargs):
  226. r"""
  227. Summarize the text(s) given as inputs.
  228. Args:
  229. documents (*str* or `list[str]`):
  230. One or several articles (or one list of articles) to summarize.
  231. return_text (`bool`, *optional*, defaults to `True`):
  232. Whether or not to include the decoded texts in the outputs
  233. return_tensors (`bool`, *optional*, defaults to `False`):
  234. Whether or not to include the tensors of predictions (as token indices) in the outputs.
  235. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
  236. Whether or not to clean up the potential extra spaces in the text output.
  237. generate_kwargs:
  238. Additional keyword arguments to pass along to the generate method of the model (see the generate method
  239. corresponding to your framework [here](./text_generation)).
  240. Return:
  241. A list or a list of list of `dict`: Each result comes as a dictionary with the following keys:
  242. - **summary_text** (`str`, present when `return_text=True`) -- The summary of the corresponding input.
  243. - **summary_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
  244. ids of the summary.
  245. """
  246. return super().__call__(*args, **kwargs)
  247. def check_inputs(self, input_length: int, min_length: int, max_length: int) -> bool:
  248. """
  249. Checks whether there might be something wrong with given input with regard to the model.
  250. """
  251. if max_length < min_length:
  252. logger.warning(f"Your min_length={min_length} must be inferior than your max_length={max_length}.")
  253. if input_length < max_length:
  254. logger.warning(
  255. f"Your max_length is set to {max_length}, but your input_length is only {input_length}. Since this is "
  256. "a summarization task, where outputs shorter than the input are typically wanted, you might "
  257. f"consider decreasing max_length manually, e.g. summarizer('...', max_length={input_length // 2})"
  258. )
  259. @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
  260. class TranslationPipeline(Text2TextGenerationPipeline):
  261. """
  262. Translates from one language to another.
  263. This translation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  264. `"translation_xx_to_yy"`.
  265. The models that this pipeline can use are models that have been fine-tuned on a translation task. See the
  266. up-to-date list of available models on [huggingface.co/models](https://huggingface.co/models?filter=translation).
  267. For a list of available parameters, see the [following
  268. documentation](https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.generation.GenerationMixin.generate)
  269. Unless the model you're using explicitly sets these generation parameters in its configuration files
  270. (`generation_config.json`), the following default values will be used:
  271. - max_new_tokens: 256
  272. - num_beams: 4
  273. Usage:
  274. ```python
  275. en_fr_translator = pipeline("translation_en_to_fr")
  276. en_fr_translator("How old are you?")
  277. ```"""
  278. # Used in the return key of the pipeline.
  279. return_name = "translation"
  280. def check_inputs(self, input_length: int, min_length: int, max_length: int):
  281. if input_length > 0.9 * max_length:
  282. logger.warning(
  283. f"Your input_length: {input_length} is bigger than 0.9 * max_length: {max_length}. You might consider "
  284. "increasing your max_length manually, e.g. translator('...', max_length=400)"
  285. )
  286. return True
  287. def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None):
  288. if getattr(self.tokenizer, "_build_translation_inputs", None):
  289. return self.tokenizer._build_translation_inputs(
  290. *args, return_tensors=self.framework, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang
  291. )
  292. else:
  293. return super()._parse_and_tokenize(*args, truncation=truncation)
  294. def _sanitize_parameters(self, src_lang=None, tgt_lang=None, **kwargs):
  295. preprocess_params, forward_params, postprocess_params = super()._sanitize_parameters(**kwargs)
  296. if src_lang is not None:
  297. preprocess_params["src_lang"] = src_lang
  298. if tgt_lang is not None:
  299. preprocess_params["tgt_lang"] = tgt_lang
  300. if src_lang is None and tgt_lang is None:
  301. # Backward compatibility, direct arguments use is preferred.
  302. task = kwargs.get("task", self.task)
  303. items = task.split("_")
  304. if task and len(items) == 4:
  305. # translation, XX, to YY
  306. preprocess_params["src_lang"] = items[1]
  307. preprocess_params["tgt_lang"] = items[3]
  308. return preprocess_params, forward_params, postprocess_params
  309. def __call__(self, *args, **kwargs):
  310. r"""
  311. Translate the text(s) given as inputs.
  312. Args:
  313. args (`str` or `list[str]`):
  314. Texts to be translated.
  315. return_tensors (`bool`, *optional*, defaults to `False`):
  316. Whether or not to include the tensors of predictions (as token indices) in the outputs.
  317. return_text (`bool`, *optional*, defaults to `True`):
  318. Whether or not to include the decoded texts in the outputs.
  319. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
  320. Whether or not to clean up the potential extra spaces in the text output.
  321. src_lang (`str`, *optional*):
  322. The language of the input. Might be required for multilingual models. Will not have any effect for
  323. single pair translation models
  324. tgt_lang (`str`, *optional*):
  325. The language of the desired output. Might be required for multilingual models. Will not have any effect
  326. for single pair translation models
  327. generate_kwargs:
  328. Additional keyword arguments to pass along to the generate method of the model (see the generate method
  329. corresponding to your framework [here](./text_generation)).
  330. Return:
  331. A list or a list of list of `dict`: Each result comes as a dictionary with the following keys:
  332. - **translation_text** (`str`, present when `return_text=True`) -- The translation.
  333. - **translation_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The
  334. token ids of the translation.
  335. """
  336. return super().__call__(*args, **kwargs)