processing_llava.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Processor class for Llava.
  17. """
  18. from typing import Optional, Union
  19. import numpy as np
  20. from ...feature_extraction_utils import BatchFeature
  21. from ...image_utils import ImageInput, get_image_size, to_numpy_array
  22. from ...processing_utils import (
  23. MultiModalData,
  24. ProcessingKwargs,
  25. ProcessorMixin,
  26. Unpack,
  27. )
  28. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  29. from ...utils import logging
  30. logger = logging.get_logger(__name__)
  31. class LlavaProcessorKwargs(ProcessingKwargs, total=False):
  32. _defaults = {
  33. "text_kwargs": {"padding": False, "return_mm_token_type_ids": False},
  34. "images_kwargs": {},
  35. }
  36. class LlavaProcessor(ProcessorMixin):
  37. r"""
  38. Constructs a LLaVa processor which wraps a LLaVa image processor and a LLaMa tokenizer into a single processor.
  39. [`LlavaProcessor`] offers all the functionalities of [`LlavaImageProcessor`] and [`LlamaTokenizerFast`]. See the
  40. [`~LlavaProcessor.__call__`] and [`~LlavaProcessor.decode`] for more information.
  41. Args:
  42. image_processor ([`LlavaImageProcessor`], *optional*):
  43. The image processor is a required input.
  44. tokenizer ([`LlamaTokenizerFast`], *optional*):
  45. The tokenizer is a required input.
  46. patch_size (`int`, *optional*):
  47. Patch size from the vision tower.
  48. vision_feature_select_strategy (`str`, *optional*):
  49. The feature selection strategy used to select the vision feature from the vision backbone.
  50. Should be same as in model's config
  51. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
  52. in a chat into a tokenizable string.
  53. image_token (`str`, *optional*, defaults to `"<image>"`):
  54. Special token used to denote image location.
  55. num_additional_image_tokens (`int`, *optional*, defaults to 0):
  56. Number of additional tokens added to the image embeddings, such as CLS (+1). If the backbone has no CLS or other
  57. extra tokens appended, no need to set this arg.
  58. """
  59. attributes = ["image_processor", "tokenizer"]
  60. image_processor_class = "AutoImageProcessor"
  61. tokenizer_class = "AutoTokenizer"
  62. def __init__(
  63. self,
  64. image_processor=None,
  65. tokenizer=None,
  66. patch_size=None,
  67. vision_feature_select_strategy=None,
  68. chat_template=None,
  69. image_token="<image>", # set the default and let users change if they have peculiar special tokens in rare cases
  70. num_additional_image_tokens=0,
  71. **kwargs,
  72. ):
  73. self.patch_size = patch_size
  74. self.num_additional_image_tokens = num_additional_image_tokens
  75. self.vision_feature_select_strategy = vision_feature_select_strategy
  76. self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
  77. self.image_token_id = tokenizer.encode(self.image_token, add_special_tokens=False)[0]
  78. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  79. def __call__(
  80. self,
  81. images: Optional[ImageInput] = None,
  82. text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
  83. audio=None,
  84. videos=None,
  85. **kwargs: Unpack[LlavaProcessorKwargs],
  86. ) -> BatchFeature:
  87. """
  88. Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
  89. and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
  90. the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
  91. CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
  92. of the above two methods for more information.
  93. Args:
  94. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
  95. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  96. tensor. Both channels-first and channels-last formats are supported.
  97. text (`str`, `list[str]`, `list[list[str]]`):
  98. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  99. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  100. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  101. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  102. If set, will return tensors of a particular framework. Acceptable values are:
  103. - `'tf'`: Return TensorFlow `tf.constant` objects.
  104. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  105. - `'np'`: Return NumPy `np.ndarray` objects.
  106. - `'jax'`: Return JAX `jnp.ndarray` objects.
  107. Returns:
  108. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  109. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  110. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  111. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  112. `None`).
  113. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  114. """
  115. if images is None and text is None:
  116. raise ValueError("You have to specify at least one of `images` or `text`.")
  117. output_kwargs = self._merge_kwargs(
  118. LlavaProcessorKwargs,
  119. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  120. **kwargs,
  121. )
  122. if images is not None:
  123. image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
  124. else:
  125. image_inputs = {}
  126. if isinstance(text, str):
  127. text = [text]
  128. elif not isinstance(text, list) and not isinstance(text[0], str):
  129. raise TypeError("Invalid input text. Please provide a string, or a list of strings")
  130. # try to expand inputs in processing if we have the necessary parts
  131. prompt_strings = text
  132. if image_inputs.get("pixel_values") is not None:
  133. # Replace the image token with the expanded image token sequence
  134. pixel_values = image_inputs["pixel_values"]
  135. height, width = get_image_size(to_numpy_array(pixel_values[0]))
  136. num_image_tokens = (height // self.patch_size) * (
  137. width // self.patch_size
  138. ) + self.num_additional_image_tokens
  139. if self.vision_feature_select_strategy == "default":
  140. num_image_tokens -= 1
  141. prompt_strings = []
  142. for sample in text:
  143. sample = sample.replace(self.image_token, self.image_token * num_image_tokens)
  144. prompt_strings.append(sample)
  145. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  146. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  147. text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
  148. self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
  149. if return_mm_token_type_ids:
  150. array_ids = np.array(text_inputs["input_ids"])
  151. mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
  152. mm_token_type_ids[array_ids == self.image_token_id] = 1
  153. text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
  154. return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
  155. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  156. """
  157. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  158. Args:
  159. image_sizes (`list[list[int]]`, *optional*):
  160. The input sizes formatted as (height, width) per each image.
  161. Returns:
  162. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  163. input modalities, along with other useful data.
  164. """
  165. vision_data = {}
  166. if image_sizes is not None:
  167. images_kwargs = LlavaProcessorKwargs._defaults.get("images_kwargs", {})
  168. images_kwargs.update(kwargs)
  169. crop_size = images_kwargs.get("crop_size", None) or self.image_processor.crop_size
  170. resized_height, resized_width = crop_size["height"], crop_size["width"]
  171. num_image_tokens = (resized_height // self.patch_size) * (resized_width // self.patch_size)
  172. num_image_tokens += self.num_additional_image_tokens
  173. if self.vision_feature_select_strategy == "default":
  174. num_image_tokens -= 1
  175. num_image_tokens = [num_image_tokens] * len(image_sizes)
  176. num_image_patches = [1] * len(image_sizes)
  177. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  178. return MultiModalData(**vision_data)
# Public API of this module. `LlavaProcessorKwargs` is not listed, so `from ... import *` does not export it.
__all__ = ["LlavaProcessor"]