processing_internvl.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. # coding=utf-8
  2. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Optional, Union
  16. import numpy as np
  17. from ...image_processing_utils import BatchFeature
  18. from ...image_utils import ImageInput, concatenate_list, make_flat_list_of_images
  19. from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
  20. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  21. from ...video_utils import VideoInput
  22. class InternVLImagesKwargs(ImagesKwargs, total=False):
  23. crop_to_patches: Optional[bool]
  24. min_patches: Optional[int]
  25. max_patches: Optional[int]
  26. class InternVLProcessorKwargs(ProcessingKwargs, total=False):
  27. images_kwargs: InternVLImagesKwargs
  28. _defaults = {
  29. "text_kwargs": {
  30. "padding_side": "left",
  31. "return_mm_token_type_ids": False,
  32. },
  33. "images_kwargs": {
  34. "crop_to_patches": True,
  35. },
  36. "videos_kwargs": {
  37. "return_tensors": "pt",
  38. },
  39. }
  40. class InternVLProcessor(ProcessorMixin):
  41. r"""
  42. Constructs a InternVL processor which wraps a [`AutoImageProcessor`] and
  43. [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
  44. tokenizer functionalities. See the [`~InternVLProcessor.__call__`] and [`~InternVLProcessor.decode`] for more information.
  45. Args:
  46. image_processor ([`AutoImageProcessor`], *optional*):
  47. The image processor is a required input.
  48. tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
  49. The tokenizer is a required input.
  50. video_processor ([`AutoVideoProcessor`], *optional*):
  51. The video processor is a required input.
  52. image_seq_length (`int`, *optional*, defaults to 256):
  53. The number of image token to use per image patch. it should be set so that:
  54. image_seq_length = (config.image_size // config.patch_size) ** 2 * (config.scale_factor**2)
  55. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
  56. in a chat into a tokenizable string.
  57. """
  58. attributes = ["image_processor", "tokenizer", "video_processor"]
  59. image_processor_class = "AutoImageProcessor"
  60. video_processor_class = "AutoVideoProcessor"
  61. tokenizer_class = "AutoTokenizer"
  62. def __init__(
  63. self,
  64. image_processor=None,
  65. tokenizer=None,
  66. video_processor=None,
  67. image_seq_length: int = 256,
  68. chat_template=None,
  69. **kwargs,
  70. ):
  71. self.image_seq_length = image_seq_length
  72. self.start_image_token = tokenizer.start_image_token
  73. self.end_image_token = tokenizer.end_image_token
  74. self.start_image_token_id = tokenizer.start_image_token_id
  75. self.end_image_token_id = tokenizer.end_image_token_id
  76. self.image_token = tokenizer.context_image_token
  77. self.video_token = tokenizer.video_token
  78. self.image_token_id = tokenizer.context_image_token_id
  79. self.image_ids = [self.image_token_id, self.start_image_token_id, self.end_image_token_id]
  80. super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template, **kwargs)
  81. def _insert_media_placeholders(
  82. self,
  83. text: list[str],
  84. image_pixel_values,
  85. video_pixel_values,
  86. image_num_patches: list[int],
  87. video_num_patches: list[int],
  88. image_num_patches_indices: np.ndarray,
  89. video_num_patches_indices: np.ndarray,
  90. video_patch_indices: np.ndarray,
  91. ):
  92. """
  93. Processes interleaved text with <image> and <video> placeholders, replacing them with appropriate
  94. image and video tokens while keeping track of the patches used.
  95. """
  96. image_index = 0
  97. video_index = 0
  98. processed_text = []
  99. image_video_patches = []
  100. replace_strings = []
  101. # Support interleaved image and video in prompts:
  102. # Processed patches of images and videos are inserted in `image_video_patches` in the order they appear in the prompts
  103. for prompt in text:
  104. new_prompt = prompt
  105. while self.image_token in new_prompt or self.video_token in new_prompt:
  106. if self.image_token in new_prompt and (
  107. self.video_token not in new_prompt
  108. or new_prompt.index(self.image_token) < new_prompt.index(self.video_token)
  109. ):
  110. # Get the slice of patches corresponding to the current image
  111. start_index = image_num_patches_indices[image_index - 1] if image_index > 0 else 0
  112. end_index = image_num_patches_indices[image_index]
  113. image_video_patches.append(image_pixel_values[start_index:end_index])
  114. # Replace the corresponding image placeholder with the correct number of image tokens
  115. new_prompt = new_prompt.replace(self.image_token, "<placeholder>", 1)
  116. replace_strings.append(
  117. f"{self.start_image_token}{self.image_token * self.image_seq_length * image_num_patches[image_index]}{self.end_image_token}"
  118. )
  119. image_index += 1
  120. else:
  121. # Get the slice of patches corresponding to the current video
  122. # Here we need to account for both the multiple video frames and the potential multiple patches per frame
  123. # As of now, InternVL only supports one patch per frame, but we keep the code flexible for future updates
  124. current_patch_index = video_patch_indices[video_index]
  125. end_patch_index = video_patch_indices[video_index + 1]
  126. start_index = video_num_patches_indices[current_patch_index]
  127. end_index = video_num_patches_indices[end_patch_index]
  128. image_video_patches.append(video_pixel_values[start_index:end_index])
  129. # Get the number of patches per frame and replace the video placeholder with the correct number of image tokens
  130. num_patches = list(video_num_patches[current_patch_index:end_patch_index])
  131. video_prompt = "\n".join(
  132. f"Frame{i + 1}: {self.start_image_token}{self.image_token * self.image_seq_length * num_patches[i]}{self.end_image_token}"
  133. for i in range(len(num_patches))
  134. )
  135. replace_strings.append(video_prompt)
  136. new_prompt = new_prompt.replace(self.video_token, "<placeholder>", 1)
  137. video_index += 1
  138. while "<placeholder>" in new_prompt:
  139. replace_str = replace_strings.pop(0)
  140. new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
  141. processed_text.append(new_prompt)
  142. return processed_text, image_video_patches, image_index, video_index
  143. def __call__(
  144. self,
  145. images: Optional[ImageInput] = None,
  146. text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
  147. audio=None,
  148. videos: Optional[VideoInput] = None,
  149. **kwargs: Unpack[InternVLProcessorKwargs],
  150. ) -> BatchFeature:
  151. """
  152. Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
  153. and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text if `text`
  154. is not `None`, otherwise encode default OCR queries which depends on the `format`, `box`, `color`, `multi_page` and
  155. `crop_to_patches` arguments. To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
  156. GotOcr2ImageProcessor's [`~GotOcr2ImageProcessor.__call__`] if `images` is not `None`.
  157. Args:
  158. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
  159. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  160. tensor. Both channels-first and channels-last formats are supported.
  161. text (`str`, `list[str]`, `list[list[str]]`):
  162. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  163. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  164. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  165. videos (`np.ndarray`, `torch.Tensor`, `list[np.ndarray]`, `list[torch.Tensor]`):
  166. The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
  167. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  168. If set, will return tensors of a particular framework. Acceptable values are:
  169. - `'tf'`: Return TensorFlow `tf.constant` objects.
  170. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  171. - `'np'`: Return NumPy `np.ndarray` objects.
  172. - `'jax'`: Return JAX `jnp.ndarray` objects.
  173. Returns:
  174. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  175. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  176. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  177. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  178. `None`).
  179. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  180. """
  181. if text is None:
  182. raise ValueError("You have to specify text.")
  183. output_kwargs = self._merge_kwargs(
  184. InternVLProcessorKwargs,
  185. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  186. **kwargs,
  187. )
  188. if not isinstance(text, (list, tuple)):
  189. text = [text]
  190. # Process images and videos separately, as videos don't support crop_to_patches
  191. image_num_patches = []
  192. image_pixel_values = None
  193. image_num_patches_indices = np.array([0])
  194. if images is not None:
  195. images = self.image_processor.fetch_images(images)
  196. images = make_flat_list_of_images(images)
  197. image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
  198. image_num_patches = image_inputs.pop("num_patches")
  199. image_pixel_values = image_inputs.pop("pixel_values")
  200. image_num_patches_indices = np.cumsum(image_num_patches)
  201. video_num_patches = [] # per frame
  202. video_pixel_values = None
  203. video_patch_indices = np.array([0])
  204. video_num_patches_indices = np.array([0])
  205. if videos is not None:
  206. video_kwargs = output_kwargs["videos_kwargs"]
  207. video_inputs = self.video_processor(videos=videos, **video_kwargs)
  208. video_pixel_values = video_inputs.pop("pixel_values_videos")
  209. batch_size, num_frames, *_ = video_pixel_values.shape
  210. num_frames_per_video = np.full(batch_size, num_frames)
  211. num_frames = sum(num_frames_per_video) # total
  212. video_patch_indices = np.empty(batch_size + 1, int)
  213. video_patch_indices[0] = 0
  214. video_patch_indices[1:] = np.cumsum(num_frames_per_video)
  215. video_num_patches = [1] * num_frames
  216. video_num_patches_indices = np.empty(num_frames + 1, int)
  217. video_num_patches_indices[0] = 0
  218. video_num_patches_indices[1:] = np.cumsum(video_num_patches)
  219. video_pixel_values = video_pixel_values.flatten(0, 1)
  220. image_videos_inputs = {}
  221. if images is not None or videos is not None:
  222. text, image_video_patches, image_index, video_index = self._insert_media_placeholders(
  223. text,
  224. image_pixel_values,
  225. video_pixel_values,
  226. image_num_patches,
  227. video_num_patches,
  228. image_num_patches_indices,
  229. video_num_patches_indices,
  230. video_patch_indices,
  231. )
  232. if images is not None and image_index != len(images):
  233. raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
  234. if videos is not None and video_index != len(num_frames_per_video):
  235. raise ValueError("Number of video placeholders in the prompt does not match the number of videos.")
  236. # Concatenate the interleaved image and video patches (function agnostic to the patches type (list, numpy array, torch tensor))
  237. image_videos_inputs = {"pixel_values": concatenate_list(image_video_patches)}
  238. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  239. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", None)
  240. text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
  241. self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
  242. if return_mm_token_type_ids:
  243. array_ids = np.array(text_inputs["input_ids"])
  244. mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
  245. mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
  246. text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
  247. return BatchFeature(data={**text_inputs, **image_videos_inputs}, tensor_type=return_tensors)
  248. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  249. """
  250. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  251. Args:
  252. image_sizes (`list[list[int]]`, *optional*):
  253. The input sizes formatted as (height, width) per each image.
  254. Returns:
  255. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  256. input modalities, along with other useful data.
  257. """
  258. vision_data = {}
  259. if image_sizes is not None:
  260. images_kwargs = InternVLProcessorKwargs._defaults.get("images_kwargs", {})
  261. images_kwargs.update(kwargs)
  262. num_image_patches = [
  263. self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
  264. for image_size in image_sizes
  265. ]
  266. # Add 2 for BOI and EOI tokens
  267. num_image_tokens = [2 + (self.image_seq_length * num_patches) for num_patches in num_image_patches]
  268. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  269. return MultiModalData(**vision_data)
  270. @property
  271. def model_input_names(self):
  272. # Overwritten because InternVL renames video inputs to `pixel_values` before returning
  273. tokenizer_input_names = self.tokenizer.model_input_names
  274. image_processor_input_names = self.image_processor.model_input_names
  275. return tokenizer_input_names + image_processor_input_names
  276. __all__ = ["InternVLProcessor"]