processing_colpali.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/colpali/modular_colpali.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_colpali.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2024 The HuggingFace Inc. team.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from typing import Optional, Union
  22. from ...feature_extraction_utils import BatchFeature
  23. from ...image_utils import ImageInput, make_flat_list_of_images
  24. from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
  25. from ...tokenization_utils_base import AddedToken, PreTokenizedInput, TextInput
  26. from ...utils import is_torch_available
  27. if is_torch_available():
  28. import torch
  29. class ColPaliProcessorKwargs(ProcessingKwargs, total=False):
  30. _defaults = {
  31. "text_kwargs": {
  32. "padding": "longest",
  33. },
  34. "images_kwargs": {
  35. "data_format": "channels_first",
  36. "do_convert_rgb": True,
  37. },
  38. "common_kwargs": {"return_tensors": "pt"},
  39. }
  40. IMAGE_TOKEN = "<image>"
  41. EXTRA_TOKENS = [f"<loc{i:0>4}>" for i in range(1024)] + [f"<seg{i:0>3}>" for i in range(128)]
  42. def build_string_from_input(prompt, bos_token, image_seq_len, image_token, num_images):
  43. """
  44. Builds a string from the input prompt and image tokens.
  45. For example, for the call:
  46. build_string_from_input(
  47. prompt="Prefix str"
  48. bos_token="<s>",
  49. image_seq_len=3,
  50. image_token="<im>",
  51. )
  52. The output will be:
  53. "<im><im><im><s>Initial str"
  54. Args:
  55. prompt (`list[Union[str, ImageInput]]`): The input prompt.
  56. bos_token (`str`): The beginning of sentence token.
  57. image_seq_len (`int`): The length of the image sequence.
  58. image_token (`str`): The image token.
  59. num_images (`int`): Number of images in the prompt.
  60. """
  61. return f"{image_token * image_seq_len * num_images}{bos_token}{prompt}\n"
  62. class ColPaliProcessor(ProcessorMixin):
  63. r"""
  64. Constructs a ColPali processor which wraps a PaliGemmaProcessor and special methods to process images and queries, as
  65. well as to compute the late-interaction retrieval score.
  66. [`ColPaliProcessor`] offers all the functionalities of [`PaliGemmaProcessor`]. See the [`~PaliGemmaProcessor.__call__`]
  67. for more information.
  68. Args:
  69. image_processor ([`SiglipImageProcessor`], *optional*):
  70. The image processor is a required input.
  71. tokenizer ([`LlamaTokenizerFast`], *optional*):
  72. The tokenizer is a required input.
  73. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
  74. in a chat into a tokenizable string.
  75. visual_prompt_prefix (`str`, *optional*, defaults to `"Describe the image."`):
  76. A string that gets tokenized and prepended to the image tokens.
  77. query_prefix (`str`, *optional*, defaults to `"Question: "`):
  78. A prefix to be used for the query.
  79. """
  80. attributes = ["image_processor", "tokenizer"]
  81. image_processor_class = ("SiglipImageProcessor", "SiglipImageProcessorFast")
  82. tokenizer_class = ("GemmaTokenizer", "GemmaTokenizerFast")
  83. def __init__(
  84. self,
  85. image_processor=None,
  86. tokenizer=None,
  87. chat_template=None,
  88. visual_prompt_prefix: str = "Describe the image.",
  89. query_prefix: str = "Question: ",
  90. ):
  91. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  92. if not hasattr(image_processor, "image_seq_length"):
  93. raise ValueError("Image processor is missing an `image_seq_length` attribute.")
  94. self.image_seq_length = image_processor.image_seq_length
  95. if not hasattr(tokenizer, "image_token"):
  96. image_token = AddedToken(IMAGE_TOKEN, normalized=False, special=True)
  97. tokens_to_add = {"additional_special_tokens": [image_token]}
  98. tokenizer.add_special_tokens(tokens_to_add)
  99. self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
  100. self.image_token = IMAGE_TOKEN
  101. else:
  102. self.image_token_id = tokenizer.image_token_id
  103. self.image_token = tokenizer.image_token
  104. tokenizer.add_tokens(EXTRA_TOKENS)
  105. tokenizer.add_bos_token = False
  106. tokenizer.add_eos_token = False
  107. self.visual_prompt_prefix = visual_prompt_prefix
  108. self.query_prefix = query_prefix
  109. def __call__(
  110. self,
  111. images: Optional[ImageInput] = None,
  112. text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
  113. audio=None,
  114. videos=None,
  115. **kwargs: Unpack[ColPaliProcessorKwargs],
  116. ) -> BatchFeature:
  117. """
  118. Main method to prepare for the model either (1) one or several texts, either (2) one or several image(s). This method is a custom
  119. wrapper around the PaliGemmaProcessor's [`~PaliGemmaProcessor.__call__`] method adapted for the ColPali model. It cannot process
  120. both text and images at the same time.
  121. When preparing the text(s), this method forwards the `text` and `kwargs` arguments to LlamaTokenizerFast's
  122. [`~LlamaTokenizerFast.__call__`].
  123. When preparing the image(s), this method forwards the `images` and `kwargs` arguments to SiglipImageProcessor's
  124. [`~SiglipImageProcessor.__call__`].
  125. Please refer to the docstring of the above two methods for more information.
  126. Args:
  127. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
  128. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  129. tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
  130. number of channels, H and W are image height and width.
  131. text (`str`, `list[str]`, `list[list[str]]`):
  132. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  133. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  134. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  135. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  136. If set, will return tensors of a particular framework. Acceptable values are:
  137. - `'tf'`: Return TensorFlow `tf.constant` objects.
  138. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  139. - `'np'`: Return NumPy `np.ndarray` objects.
  140. - `'jax'`: Return JAX `jnp.ndarray` objects.
  141. Returns:
  142. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  143. - **input_ids** -- List of token ids to be fed to a model.
  144. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  145. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  146. `None`).
  147. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  148. """
  149. output_kwargs = self._merge_kwargs(
  150. ColPaliProcessorKwargs,
  151. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  152. **kwargs,
  153. )
  154. suffix = output_kwargs["text_kwargs"].pop("suffix", None)
  155. return_token_type_ids = suffix is not None
  156. if text is None and images is None:
  157. raise ValueError("Either text or images must be provided")
  158. if text is not None and images is not None:
  159. raise ValueError("Only one of text or images can be processed at a time")
  160. if images is not None:
  161. images = self.image_processor.fetch_images(images)
  162. images = make_flat_list_of_images(images)
  163. texts_doc = [self.visual_prompt_prefix] * len(images)
  164. images = [image.convert("RGB") for image in images]
  165. input_strings = [
  166. build_string_from_input(
  167. prompt=prompt,
  168. bos_token=self.tokenizer.bos_token,
  169. image_seq_len=self.image_seq_length,
  170. image_token=IMAGE_TOKEN,
  171. num_images=len(image_list) if isinstance(image_list, list) else 1,
  172. )
  173. for prompt, image_list in zip(texts_doc, images)
  174. ]
  175. pixel_values = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]
  176. # max_length has to account for the image tokens
  177. if output_kwargs["text_kwargs"].get("max_length", None) is not None:
  178. output_kwargs["text_kwargs"]["max_length"] += self.image_seq_length
  179. inputs = self.tokenizer(
  180. input_strings,
  181. return_token_type_ids=False,
  182. **output_kwargs["text_kwargs"],
  183. )
  184. return_data = {**inputs, "pixel_values": pixel_values}
  185. if return_token_type_ids:
  186. labels = inputs["input_ids"].masked_fill(inputs["token_type_ids"] == 0, -100)
  187. return_data.update({"labels": labels})
  188. return BatchFeature(data=return_data)
  189. elif text is not None:
  190. if isinstance(text, str):
  191. text = [text]
  192. elif not (isinstance(text, list) and isinstance(text[0], str)):
  193. raise ValueError("Text must be a string or a list of strings")
  194. if suffix is None:
  195. suffix = self.query_augmentation_token * 10
  196. texts_query: list[str] = []
  197. for query in text:
  198. query = self.tokenizer.bos_token + self.query_prefix + query + suffix + "\n"
  199. texts_query.append(query)
  200. output_kwargs["text_kwargs"]["max_length"] = output_kwargs["text_kwargs"].get("max_length", 50)
  201. batch_query = self.tokenizer(
  202. texts_query,
  203. return_token_type_ids=False,
  204. **output_kwargs["text_kwargs"],
  205. )
  206. return batch_query
  207. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  208. """
  209. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  210. Args:
  211. image_sizes (list[list[str]], *optional*):
  212. The input sizes formatted as (height, width) per each image.
  213. Returns:
  214. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  215. input modalities, along with other useful data.
  216. """
  217. vision_data = {}
  218. if image_sizes is not None:
  219. num_image_tokens = [self.image_seq_length] * len(image_sizes)
  220. num_image_patches = [1] * len(image_sizes)
  221. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  222. return MultiModalData(**vision_data)
  223. @property
  224. def query_augmentation_token(self) -> str:
  225. """
  226. Return the query augmentation token.
  227. Query augmentation buffers are used as reasoning buffers during inference.
  228. """
  229. return self.tokenizer.pad_token
  230. def process_images(
  231. self,
  232. images: Optional[ImageInput] = None,
  233. **kwargs: Unpack[ColPaliProcessorKwargs],
  234. ) -> BatchFeature:
  235. """
  236. Prepare for the model one or several image(s). This method is a wrapper around the `__call__` method of the ColPaliProcessor's
  237. [`ColPaliProcessor.__call__`].
  238. This method forwards the `images` and `kwargs` arguments to the image processor.
  239. Args:
  240. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
  241. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  242. tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
  243. number of channels, H and W are image height and width.
  244. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  245. If set, will return tensors of a particular framework. Acceptable values are:
  246. - `'tf'`: Return TensorFlow `tf.constant` objects.
  247. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  248. - `'np'`: Return NumPy `np.ndarray` objects.
  249. - `'jax'`: Return JAX `jnp.ndarray` objects.
  250. Returns:
  251. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  252. - **input_ids** -- List of token ids to be fed to a model.
  253. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  254. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  255. `None`).
  256. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  257. """
  258. return self.__call__(images=images, **kwargs)
  259. def process_queries(
  260. self,
  261. text: Union[TextInput, list[TextInput]],
  262. **kwargs: Unpack[ColPaliProcessorKwargs],
  263. ) -> BatchFeature:
  264. """
  265. Prepare for the model one or several texts. This method is a wrapper around the `__call__` method of the ColPaliProcessor's
  266. [`ColPaliProcessor.__call__`].
  267. This method forwards the `text` and `kwargs` arguments to the tokenizer.
  268. Args:
  269. text (`str`, `list[str]`, `list[list[str]]`):
  270. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  271. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  272. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  273. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  274. If set, will return tensors of a particular framework. Acceptable values are:
  275. - `'tf'`: Return TensorFlow `tf.constant` objects.
  276. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  277. - `'np'`: Return NumPy `np.ndarray` objects.
  278. - `'jax'`: Return JAX `jnp.ndarray` objects.
  279. Returns:
  280. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  281. - **input_ids** -- List of token ids to be fed to a model.
  282. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  283. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  284. `None`).
  285. """
  286. return self.__call__(text=text, **kwargs)
  287. def score_retrieval(
  288. self,
  289. query_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
  290. passage_embeddings: Union["torch.Tensor", list["torch.Tensor"]],
  291. batch_size: int = 128,
  292. output_dtype: Optional["torch.dtype"] = None,
  293. output_device: Union["torch.device", str] = "cpu",
  294. ) -> "torch.Tensor":
  295. """
  296. Compute the late-interaction/MaxSim score (ColBERT-like) for the given multi-vector
  297. query embeddings (`qs`) and passage embeddings (`ps`). For ColPali, a passage is the
  298. image of a document page.
  299. Because the embedding tensors are multi-vector and can thus have different shapes, they
  300. should be fed as:
  301. (1) a list of tensors, where the i-th tensor is of shape (sequence_length_i, embedding_dim)
  302. (2) a single tensor of shape (n_passages, max_sequence_length, embedding_dim) -> usually
  303. obtained by padding the list of tensors.
  304. Args:
  305. query_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Query embeddings.
  306. passage_embeddings (`Union[torch.Tensor, list[torch.Tensor]`): Passage embeddings.
  307. batch_size (`int`, *optional*, defaults to 128): Batch size for computing scores.
  308. output_dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The dtype of the output tensor.
  309. If `None`, the dtype of the input embeddings is used.
  310. output_device (`torch.device` or `str`, *optional*, defaults to "cpu"): The device of the output tensor.
  311. Returns:
  312. `torch.Tensor`: A tensor of shape `(n_queries, n_passages)` containing the scores. The score
  313. tensor is saved on the "cpu" device.
  314. """
  315. if len(query_embeddings) == 0:
  316. raise ValueError("No queries provided")
  317. if len(passage_embeddings) == 0:
  318. raise ValueError("No passages provided")
  319. if query_embeddings[0].device != passage_embeddings[0].device:
  320. raise ValueError("Queries and passages must be on the same device")
  321. if query_embeddings[0].dtype != passage_embeddings[0].dtype:
  322. raise ValueError("Queries and passages must have the same dtype")
  323. if output_dtype is None:
  324. output_dtype = query_embeddings[0].dtype
  325. scores: list[torch.Tensor] = []
  326. for i in range(0, len(query_embeddings), batch_size):
  327. batch_scores: list[torch.Tensor] = []
  328. batch_queries = torch.nn.utils.rnn.pad_sequence(
  329. query_embeddings[i : i + batch_size], batch_first=True, padding_value=0
  330. )
  331. for j in range(0, len(passage_embeddings), batch_size):
  332. batch_passages = torch.nn.utils.rnn.pad_sequence(
  333. passage_embeddings[j : j + batch_size], batch_first=True, padding_value=0
  334. )
  335. batch_scores.append(
  336. torch.einsum("bnd,csd->bcns", batch_queries, batch_passages).max(dim=3)[0].sum(dim=2)
  337. )
  338. scores.append(torch.cat(batch_scores, dim=1).to(output_dtype).to(output_device))
  339. return torch.cat(scores, dim=0)
# Names exported by `from <module> import *`: only the processor class.
__all__ = ["ColPaliProcessor"]