image_to_text.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # coding=utf-8
  2. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Any, Union, overload
  16. from ..generation import GenerationConfig
  17. from ..utils import (
  18. add_end_docstrings,
  19. is_tf_available,
  20. is_torch_available,
  21. is_vision_available,
  22. logging,
  23. requires_backends,
  24. )
  25. from .base import Pipeline, build_pipeline_init_args
  26. if is_vision_available():
  27. from PIL import Image
  28. from ..image_utils import load_image
  29. if is_tf_available():
  30. from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
  31. if is_torch_available():
  32. import torch
  33. from ..models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
  34. logger = logging.get_logger(__name__)
  35. @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True, has_image_processor=True))
  36. class ImageToTextPipeline(Pipeline):
  37. """
  38. Image To Text pipeline using a `AutoModelForVision2Seq`. This pipeline predicts a caption for a given image.
  39. Unless the model you're using explicitly sets these generation parameters in its configuration files
  40. (`generation_config.json`), the following default values will be used:
  41. - max_new_tokens: 256
  42. Example:
  43. ```python
  44. >>> from transformers import pipeline
  45. >>> captioner = pipeline(model="ydshieh/vit-gpt2-coco-en")
  46. >>> captioner("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png")
  47. [{'generated_text': 'two birds are standing next to each other '}]
  48. ```
  49. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  50. This image to text pipeline can currently be loaded from pipeline() using the following task identifier:
  51. "image-to-text".
  52. See the list of available models on
  53. [huggingface.co/models](https://huggingface.co/models?pipeline_tag=image-to-text).
  54. """
  55. _pipeline_calls_generate = True
  56. _load_processor = False
  57. _load_image_processor = True
  58. _load_feature_extractor = False
  59. _load_tokenizer = True
  60. # Make sure the docstring is updated when the default generation config is changed
  61. _default_generation_config = GenerationConfig(
  62. max_new_tokens=256,
  63. )
  64. def __init__(self, *args, **kwargs):
  65. super().__init__(*args, **kwargs)
  66. requires_backends(self, "vision")
  67. self.check_model_type(
  68. TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
  69. )
  70. def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt=None, timeout=None):
  71. forward_params = {}
  72. preprocess_params = {}
  73. if prompt is not None:
  74. preprocess_params["prompt"] = prompt
  75. if timeout is not None:
  76. preprocess_params["timeout"] = timeout
  77. if max_new_tokens is not None:
  78. forward_params["max_new_tokens"] = max_new_tokens
  79. if generate_kwargs is not None:
  80. if max_new_tokens is not None and "max_new_tokens" in generate_kwargs:
  81. raise ValueError(
  82. "`max_new_tokens` is defined both as an argument and inside `generate_kwargs` argument, please use"
  83. " only 1 version"
  84. )
  85. forward_params.update(generate_kwargs)
  86. if self.assistant_model is not None:
  87. forward_params["assistant_model"] = self.assistant_model
  88. if self.assistant_tokenizer is not None:
  89. forward_params["tokenizer"] = self.tokenizer
  90. forward_params["assistant_tokenizer"] = self.assistant_tokenizer
  91. return preprocess_params, forward_params, {}
  92. @overload
  93. def __call__(self, inputs: Union[str, "Image.Image"], **kwargs: Any) -> list[dict[str, Any]]: ...
  94. @overload
  95. def __call__(self, inputs: Union[list[str], list["Image.Image"]], **kwargs: Any) -> list[list[dict[str, Any]]]: ...
  96. def __call__(self, inputs: Union[str, list[str], "Image.Image", list["Image.Image"]], **kwargs):
  97. """
  98. Assign labels to the image(s) passed as inputs.
  99. Args:
  100. inputs (`str`, `list[str]`, `PIL.Image` or `list[PIL.Image]`):
  101. The pipeline handles three types of images:
  102. - A string containing a HTTP(s) link pointing to an image
  103. - A string containing a local path to an image
  104. - An image loaded in PIL directly
  105. The pipeline accepts either a single image or a batch of images.
  106. max_new_tokens (`int`, *optional*):
  107. The amount of maximum tokens to generate. By default it will use `generate` default.
  108. generate_kwargs (`Dict`, *optional*):
  109. Pass it to send all of these arguments directly to `generate` allowing full control of this function.
  110. timeout (`float`, *optional*, defaults to None):
  111. The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
  112. the call may block forever.
  113. Return:
  114. A list or a list of list of `dict`: Each result comes as a dictionary with the following key:
  115. - **generated_text** (`str`) -- The generated text.
  116. """
  117. # After deprecation of this is completed, remove the default `None` value for `images`
  118. if "images" in kwargs:
  119. inputs = kwargs.pop("images")
  120. if inputs is None:
  121. raise ValueError("Cannot call the image-to-text pipeline without an inputs argument!")
  122. return super().__call__(inputs, **kwargs)
  123. def preprocess(self, image, prompt=None, timeout=None):
  124. image = load_image(image, timeout=timeout)
  125. if prompt is not None:
  126. logger.warning_once(
  127. "Passing `prompt` to the `image-to-text` pipeline is deprecated and will be removed in version 4.48"
  128. " of 🤗 Transformers. Use the `image-text-to-text` pipeline instead",
  129. )
  130. if not isinstance(prompt, str):
  131. raise ValueError(
  132. f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "
  133. "Note also that one single text can be provided for conditional image to text generation."
  134. )
  135. model_type = self.model.config.model_type
  136. if model_type == "git":
  137. model_inputs = self.image_processor(images=image, return_tensors=self.framework)
  138. if self.framework == "pt":
  139. model_inputs = model_inputs.to(self.dtype)
  140. input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
  141. input_ids = [self.tokenizer.cls_token_id] + input_ids
  142. input_ids = torch.tensor(input_ids).unsqueeze(0)
  143. model_inputs.update({"input_ids": input_ids})
  144. elif model_type == "pix2struct":
  145. model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
  146. if self.framework == "pt":
  147. model_inputs = model_inputs.to(self.dtype)
  148. elif model_type != "vision-encoder-decoder":
  149. # vision-encoder-decoder does not support conditional generation
  150. model_inputs = self.image_processor(images=image, return_tensors=self.framework)
  151. if self.framework == "pt":
  152. model_inputs = model_inputs.to(self.dtype)
  153. text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
  154. model_inputs.update(text_inputs)
  155. else:
  156. raise ValueError(f"Model type {model_type} does not support conditional text generation")
  157. else:
  158. model_inputs = self.image_processor(images=image, return_tensors=self.framework)
  159. if self.framework == "pt":
  160. model_inputs = model_inputs.to(self.dtype)
  161. if self.model.config.model_type == "git" and prompt is None:
  162. model_inputs["input_ids"] = None
  163. return model_inputs
  164. def _forward(self, model_inputs, **generate_kwargs):
  165. # Git model sets `model_inputs["input_ids"] = None` in `preprocess` (when `prompt=None`). In batch model, the
  166. # pipeline will group them into a list of `None`, which fail `_forward`. Avoid this by checking it first.
  167. if (
  168. "input_ids" in model_inputs
  169. and isinstance(model_inputs["input_ids"], list)
  170. and all(x is None for x in model_inputs["input_ids"])
  171. ):
  172. model_inputs["input_ids"] = None
  173. # User-defined `generation_config` passed to the pipeline call take precedence
  174. if "generation_config" not in generate_kwargs:
  175. generate_kwargs["generation_config"] = self.generation_config
  176. # FIXME: We need to pop here due to a difference in how `generation.py` and `generation.tf_utils.py`
  177. # parse inputs. In the Tensorflow version, `generate` raises an error if we don't use `input_ids` whereas
  178. # the PyTorch version matches it with `self.model.main_input_name` or `self.model.encoder.main_input_name`
  179. # in the `_prepare_model_inputs` method.
  180. inputs = model_inputs.pop(self.model.main_input_name)
  181. model_outputs = self.model.generate(inputs, **model_inputs, **generate_kwargs)
  182. return model_outputs
  183. def postprocess(self, model_outputs):
  184. records = []
  185. for output_ids in model_outputs:
  186. record = {
  187. "generated_text": self.tokenizer.decode(
  188. output_ids,
  189. skip_special_tokens=True,
  190. )
  191. }
  192. records.append(record)
  193. return records