processing_mllama.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. # coding=utf-8
  2. # Copyright 2024 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Processor class for Mllama."""
  16. from typing import Optional, Union
  17. import numpy as np
  18. from ...feature_extraction_utils import BatchFeature
  19. from ...image_utils import ImageInput, make_nested_list_of_images
  20. from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
  21. from ...tokenization_utils_base import PreTokenizedInput, TextInput
class MllamaImagesKwargs(ImagesKwargs, total=False):
    """Image keyword arguments accepted by the Mllama processor, extending the generic `ImagesKwargs`."""

    # Optional per-call tile limit; it is forwarded to the image processor together with the other
    # image kwargs. Presumably caps how many tiles a single image may be split into -- confirm
    # against the image processor.
    max_image_tiles: Optional[int]
class MllamaProcessorKwargs(ProcessingKwargs, total=False):
    """Per-modality keyword-argument schema used by `MllamaProcessor.__call__`."""

    # Image kwargs use the Mllama-specific schema so that `max_image_tiles` is recognised.
    images_kwargs: MllamaImagesKwargs

    # NOTE(review): the key below is spelled "image_kwargs" while the annotated field above is
    # `images_kwargs`. If the kwargs-merging logic looks defaults up by field name, this default is
    # never applied. Confirm before renaming: applying it would start passing `max_image_tiles=4`
    # to the image processor on every call, which is a behaviour change.
    _defaults = {
        "image_kwargs": {
            "max_image_tiles": 4,
        },
    }
  31. def get_cross_attention_token_mask(input_ids: list[int], image_token_id: int) -> list[list[int]]:
  32. """
  33. Generate a cross-attention token mask for image tokens in the input sequence.
  34. This function identifies the positions of image tokens in the input sequence and creates
  35. a mask that defines which subsequent tokens each image token should attend to.
  36. Args:
  37. input_ids (list[int]): A list of token ids representing the input sequence.
  38. image_token_id (int): The id of the token used to represent images in the sequence.
  39. Returns:
  40. list[list[int]]: A list of [start, end] pairs, where each pair represents the range
  41. of tokens an image token should attend to.
  42. Notes:
  43. - If no image tokens are present, an empty list is returned.
  44. - For a single image token, it attends to all subsequent tokens until the end of the sequence.
  45. - For multiple image tokens, each attends to tokens up to the next image token or the end of the sequence.
  46. - Consecutive image tokens are treated as a group and attend to all subsequent tokens together.
  47. """
  48. image_token_locations = [i for i, token in enumerate(input_ids) if token == image_token_id]
  49. if len(image_token_locations) == 0:
  50. return []
  51. # only one image present, unmask until end of sequence
  52. if len(image_token_locations) == 1:
  53. return [[image_token_locations[0], -1]]
  54. vision_masks = [[loc1, loc2] for loc1, loc2 in zip(image_token_locations[:-1], image_token_locations[1:])]
  55. # last image will attend to all subsequent text
  56. vision_masks.append([image_token_locations[-1], len(input_ids)])
  57. # if there are two or more consecutive vision tokens,
  58. # they should all attend to all subsequent
  59. # text present
  60. last_mask_end = vision_masks[-1][1]
  61. for vision_mask in vision_masks[::-1]:
  62. if vision_mask[0] == vision_mask[1] - 1:
  63. vision_mask[1] = last_mask_end
  64. last_mask_end = vision_mask[1]
  65. return vision_masks
  66. def convert_sparse_cross_attention_mask_to_dense(
  67. cross_attention_token_mask: list[list[list[int]]],
  68. num_tiles: list[list[int]],
  69. max_num_tiles: int,
  70. length: int,
  71. ) -> np.ndarray:
  72. """
  73. Convert the cross attention mask indices to a cross attention mask 4D array.
  74. This function takes a sparse representation of cross attention masks and converts it to a dense 4D numpy array.
  75. The sparse representation is a nested list structure that defines attention ranges for each image in each batch item.
  76. Args:
  77. cross_attention_token_mask (list[list[list[int]]]): A nested list structure where:
  78. - The outer list represents the batch dimension.
  79. - The middle list represents different images within each batch item.
  80. - The inner list contains pairs of integers [start, end] representing token ranges for each image.
  81. num_tiles (list[list[int]]): A nested list structure specifying the number of tiles for each image in each batch item.
  82. max_num_tiles (int): The maximum possible number of tiles.
  83. length (int): The total sequence length of the input.
  84. Returns:
  85. np.ndarray: A 4D numpy array of shape (batch_size, length, max_num_images, max_num_tiles)
  86. The array contains `1` where attention is allowed and `0` where it is not.
  87. Note:
  88. - Special handling is done for cases where the end token is -1, which is interpreted as attending to the end of the sequence.
  89. """
  90. batch_size = len(cross_attention_token_mask)
  91. max_num_images = max(len(masks) for masks in cross_attention_token_mask)
  92. cross_attention_mask = np.zeros(
  93. shape=(batch_size, length, max_num_images, max_num_tiles),
  94. dtype=np.int64,
  95. )
  96. for sample_idx, (sample_masks, sample_num_tiles) in enumerate(zip(cross_attention_token_mask, num_tiles)):
  97. for mask_idx, (locations, mask_num_tiles) in enumerate(zip(sample_masks, sample_num_tiles)):
  98. if len(locations) == 2:
  99. start, end = locations
  100. end = min(end, length)
  101. if end == -1:
  102. end = length
  103. cross_attention_mask[sample_idx, start:end, mask_idx, :mask_num_tiles] = 1
  104. return cross_attention_mask
  105. def build_string_from_input(prompt: str, bos_token: str, image_token: str) -> str:
  106. """
  107. Builds a string from the input prompt by adding `bos_token` if not already present.
  108. Args:
  109. prompt (`str`):
  110. The input prompt string.
  111. bos_token (`str`):
  112. The beginning of sentence token to be added.
  113. image_token (`str`):
  114. The image token used to identify the start of an image sequence.
  115. Returns:
  116. str: The modified prompt string with the `bos_token` added if necessary.
  117. Examples:
  118. >>> build_string_from_input("Hello world", "<begin_of_text>", "<|image|>")
  119. '<begin_of_text>Hello world'
  120. >>> build_string_from_input("<|image|>Hello world", "<begin_of_text>", "<|image|>")
  121. '<|image|><begin_of_text>Hello world'
  122. >>> build_string_from_input("<begin_of_text>Hello world", "<begin_of_text>", "<|image|>")
  123. '<begin_of_text>Hello world'
  124. """
  125. if bos_token in prompt:
  126. return prompt
  127. num_image_tokens_on_start = 0
  128. while prompt.startswith(image_token):
  129. prompt = prompt[len(image_token) :]
  130. num_image_tokens_on_start += 1
  131. return f"{image_token * num_image_tokens_on_start}{bos_token}{prompt}"
  132. class MllamaProcessor(ProcessorMixin):
  133. r"""
  134. Constructs a Mllama processor which wraps [`MllamaImageProcessor`] and
  135. [`PretrainedTokenizerFast`] into a single processor that inherits both the image processor and
  136. tokenizer functionalities. See the [`~MllamaProcessor.__call__`] and [`~OwlViTProcessor.decode`] for more
  137. information.
  138. The preferred way of passing kwargs is as a dictionary per modality, see usage example below.
  139. ```python
  140. from transformers import MllamaProcessor
  141. from PIL import Image
  142. processor = MllamaProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
  143. processor(
  144. images=your_pil_image,
  145. text=["<|image|>If I had to write a haiku for this one"],
  146. images_kwargs = {"size": {"height": 448, "width": 448}},
  147. text_kwargs = {"padding": "right"},
  148. common_kwargs = {"return_tensors": "pt"},
  149. )
  150. ```
  151. Args:
  152. image_processor ([`MllamaImageProcessor`]):
  153. The image processor is a required input.
  154. tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
  155. The tokenizer is a required input.
  156. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
  157. in a chat into a tokenizable string.
  158. """
  159. attributes = ["image_processor", "tokenizer"]
  160. image_processor_class = "MllamaImageProcessor"
  161. tokenizer_class = "PreTrainedTokenizerFast"
  162. def __init__(self, image_processor, tokenizer, chat_template=None):
  163. if not hasattr(tokenizer, "image_token"):
  164. self.image_token = "<|image|>"
  165. self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
  166. else:
  167. self.image_token = tokenizer.image_token
  168. self.image_token_id = tokenizer.image_token_id
  169. self.python_token = "<|python_tag|>"
  170. self.python_token_id = tokenizer.convert_tokens_to_ids(self.python_token)
  171. self.bos_token = tokenizer.bos_token
  172. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  173. def __call__(
  174. self,
  175. images: Optional[ImageInput] = None,
  176. text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
  177. audio=None,
  178. videos=None,
  179. **kwargs: Unpack[MllamaProcessorKwargs],
  180. ) -> BatchFeature:
  181. """
  182. Main method to prepare text(s) and image(s) to be fed as input to the model. This method forwards the `text`
  183. arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
  184. the text. To prepare the image(s), this method forwards the `images` arguments to
  185. MllamaImageProcessor's [`~MllamaImageProcessor.__call__`] if `images` is not `None`. Please refer
  186. to the docstring of the above two methods for more information.
  187. Args:
  188. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
  189. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  190. tensor. Both channels-first and channels-last formats are supported.
  191. text (`str`, `list[str]`, `list[list[str]]`):
  192. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  193. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  194. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  195. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  196. If set, will return tensors of a particular framework. Acceptable values are:
  197. - `'tf'`: Return TensorFlow `tf.constant` objects.
  198. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  199. - `'np'`: Return NumPy `np.ndarray` objects.
  200. - `'jax'`: Return JAX `jnp.ndarray` objects.
  201. Returns:
  202. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  203. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  204. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  205. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  206. `None`).
  207. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  208. TODO: add aspect_ratio_ids and aspect_ratio_mask and cross_attention_mask
  209. """
  210. if text is None and images is None:
  211. raise ValueError("You must specify either text or images.")
  212. output_kwargs = self._merge_kwargs(
  213. MllamaProcessorKwargs,
  214. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  215. **kwargs,
  216. )
  217. text_kwargs = output_kwargs["text_kwargs"]
  218. text_kwargs["return_tensors"] = None
  219. images_kwargs = output_kwargs["images_kwargs"]
  220. common_kwargs = output_kwargs["common_kwargs"]
  221. data = {}
  222. if text is not None:
  223. if isinstance(text, str):
  224. text = [text]
  225. elif not (isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text)):
  226. raise ValueError("Invalid input text. Please provide a string, or a list of strings")
  227. n_images_in_text = [t.count(self.image_token) for t in text]
  228. text = [build_string_from_input(text_item, self.bos_token, self.image_token) for text_item in text]
  229. _ = text_kwargs.pop("padding_side", None) # hack until padding-side is an accepted kwarg by tokenizers
  230. encoding = self.tokenizer(text, **text_kwargs)
  231. self._check_special_mm_tokens(text, encoding, modalities=["image"])
  232. n_images_in_ids = [token_ids.count(self.image_token_id) for token_ids in encoding["input_ids"]]
  233. data.update(encoding)
  234. n_images_in_images = [0]
  235. if images is not None:
  236. images = self.image_processor.fetch_images(images)
  237. images = make_nested_list_of_images(images)
  238. n_images_in_images = [len(sample) for sample in images]
  239. if text is not None:
  240. if any(batch_img == 0 for batch_img in n_images_in_text) and not all(
  241. batch_img == 0 for batch_img in n_images_in_text
  242. ):
  243. raise ValueError(
  244. "If a batch of text is provided, there should be either no images or at least one image per sample"
  245. )
  246. if sum(n_images_in_text) > 0 and (
  247. n_images_in_images != n_images_in_text or n_images_in_ids != n_images_in_images
  248. ):
  249. if images is None:
  250. raise ValueError("No image were provided, but there are image tokens in the prompt")
  251. else:
  252. add_message = ""
  253. if sum(n_images_in_images) == sum(n_images_in_text) and n_images_in_images != n_images_in_text:
  254. add_message = "Make sure to pass your images as a nested list, where each sub-list holds images per batch"
  255. elif n_images_in_ids != n_images_in_images:
  256. add_message = "If you activated truncation with `max_length`, increase the `max_length` so image tokens aren't cropped."
  257. raise ValueError(
  258. f"The number of image tokens in each text ({n_images_in_text}) should be the same as the "
  259. f"number of provided images per batch ({n_images_in_images}). {add_message}"
  260. )
  261. if images is not None:
  262. image_features = self.image_processor(images, **images_kwargs)
  263. num_tiles = image_features.pop("num_tiles")
  264. data.update(image_features)
  265. # Create cross attention mask
  266. if images is not None and text is not None:
  267. cross_attention_token_mask = [
  268. get_cross_attention_token_mask(token_ids, self.image_token_id) for token_ids in encoding["input_ids"]
  269. ]
  270. cross_attention_mask = convert_sparse_cross_attention_mask_to_dense(
  271. cross_attention_token_mask,
  272. num_tiles=num_tiles,
  273. max_num_tiles=self.image_processor.max_image_tiles,
  274. length=max(len(input_ids) for input_ids in encoding["input_ids"]),
  275. )
  276. data["cross_attention_mask"] = cross_attention_mask
  277. return_tensors = common_kwargs.pop("return_tensors", None)
  278. batch_feature = BatchFeature(data=data, tensor_type=return_tensors)
  279. return batch_feature
  280. def post_process_image_text_to_text(
  281. self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
  282. ):
  283. """
  284. Post-process the output of the model to decode the text.
  285. Args:
  286. generated_outputs (`torch.Tensor` or `np.ndarray`):
  287. The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
  288. or `(sequence_length,)`.
  289. skip_special_tokens (`bool`, *optional*, defaults to `True`):
  290. Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
  291. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
  292. Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
  293. **kwargs:
  294. Additional arguments to be passed to the tokenizer's `batch_decode method`.
  295. Returns:
  296. `list[str]`: The decoded text.
  297. """
  298. return self.tokenizer.batch_decode(
  299. generated_outputs,
  300. skip_special_tokens=skip_special_tokens,
  301. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  302. **kwargs,
  303. )
  304. @property
  305. def model_input_names(self):
  306. tokenizer_input_names = self.tokenizer.model_input_names
  307. image_processor_input_names = self.image_processor.model_input_names
  308. # Remove `num_tiles`, it is popped and used only when processing. Make a copy of list when removing
  309. # otherwise `self.image_processor.model_input_names` is also modified
  310. image_processor_input_names = [name for name in image_processor_input_names if name != "num_tiles"]
  311. return list(tokenizer_input_names + image_processor_input_names + ["cross_attention_mask"])
# Names exported by `from <module> import *`: only the processor class is public.
__all__ = ["MllamaProcessor"]