processing_chameleon.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # coding=utf-8
  2. # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Processor class for Chameleon.
  17. """
  18. from typing import Optional, Union
  19. import numpy as np
  20. from ...feature_extraction_utils import BatchFeature
  21. from ...image_utils import ImageInput
  22. from ...processing_utils import (
  23. MultiModalData,
  24. ProcessingKwargs,
  25. ProcessorMixin,
  26. TextKwargs,
  27. Unpack,
  28. )
  29. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  30. class ChameleonTextKwargs(TextKwargs, total=False):
  31. return_for_text_completion: bool
  32. class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
  33. text_kwargs: ChameleonTextKwargs
  34. _defaults = {
  35. "text_kwargs": {
  36. "padding": False,
  37. "return_for_text_completion": False,
  38. "return_mm_token_type_ids": False,
  39. },
  40. "common_kwargs": {
  41. "return_tensors": "pt",
  42. },
  43. }
  44. class ChameleonProcessor(ProcessorMixin):
  45. r"""
  46. Constructs a Chameleon processor which wraps a Chameleon image processor and a Chameleon tokenizer into a single
  47. processor.
  48. [`ChameleonProcessor`] offers all the functionalities of [`ChameleonImageProcessor`] and [`LlamaTokenizerFast`].
  49. See the [`~ChameleonProcessor.__call__`] and [`~ChameleonProcessor.decode`] for more information.
  50. Args:
  51. image_processor ([`ChameleonImageProcessor`]):
  52. The image processor is a required input.
  53. tokenizer ([`LlamaTokenizerFast`]):
  54. The tokenizer is a required input.
  55. image_seq_length (`int`, *optional*, defaults to 1024):
  56. Sequence length of one image embedding.
  57. image_token (`str`, *optional*, defaults to `"<image>"`):
  58. The special token used to indicate image in the text.
  59. """
  60. attributes = ["image_processor", "tokenizer"]
  61. tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
  62. image_processor_class = "ChameleonImageProcessor"
  63. def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, image_token: str = "<image>"):
  64. self.image_seq_length = image_seq_length
  65. self.image_token = tokenizer.image_token if hasattr(tokenizer, "image_token") else image_token
  66. self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
  67. self.image_start_token = (
  68. tokenizer.boi_token if hasattr(tokenizer, "boi_token") else "<racm3:break>"
  69. ) # fixed tokens for start and end, so can hardcode
  70. self.image_end_token = tokenizer.eoi_token if hasattr(tokenizer, "eoi_token") else "<eoss>"
  71. self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
  72. self.image_start_token_id = tokenizer.convert_tokens_to_ids(self.image_start_token)
  73. self.image_end_token_id = tokenizer.convert_tokens_to_ids(self.image_end_token)
  74. self.image_ids = [self.image_token_id, self.image_start_token_id, self.image_end_token_id]
  75. super().__init__(image_processor, tokenizer)
  76. def __call__(
  77. self,
  78. images: Optional[ImageInput] = None,
  79. text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
  80. audio=None,
  81. videos=None,
  82. **kwargs: Unpack[ChameleonProcessorKwargs],
  83. ) -> BatchFeature:
  84. """
  85. Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
  86. and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to encode
  87. the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
  88. CLIPImageProcessor's [`~CLIPImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
  89. of the above two methods for more information.
  90. Args:
  91. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
  92. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  93. tensor. Both channels-first and channels-last formats are supported.
  94. text (`str`, `list[str]`, `list[list[str]]`):
  95. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  96. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  97. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  98. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  99. If set, will return tensors of a particular framework. Acceptable values are:
  100. - `'tf'`: Return TensorFlow `tf.constant` objects.
  101. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  102. - `'np'`: Return NumPy `np.ndarray` objects.
  103. - `'jax'`: Return JAX `jnp.ndarray` objects.
  104. Returns:
  105. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  106. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  107. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  108. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  109. `None`).
  110. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  111. """
  112. if isinstance(text, str):
  113. text = [text]
  114. elif not isinstance(text, list) and not isinstance(text[0], str):
  115. raise TypeError("Invalid input text. Please provide a string, or a list of strings")
  116. if text is None and images is None:
  117. raise ValueError("You must provide either text or images")
  118. output_kwargs = self._merge_kwargs(
  119. ChameleonProcessorKwargs,
  120. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  121. **kwargs,
  122. )
  123. return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)
  124. # Replace the image token with the expanded image token sequence
  125. prompt_strings = []
  126. one_img_tokens = self.image_start_token + (self.image_token * self.image_seq_length) + self.image_end_token
  127. for sample in text:
  128. sample = sample.replace(self.image_token, one_img_tokens)
  129. if not return_for_text_completion:
  130. sample += self.tokenizer.sep_token # special Chameleon treatment to add sep for chat mode
  131. prompt_strings.append(sample)
  132. image_inputs = {}
  133. if images is not None:
  134. image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
  135. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  136. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  137. text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
  138. self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
  139. if return_mm_token_type_ids:
  140. array_ids = np.array(text_inputs["input_ids"])
  141. mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
  142. mm_token_type_ids[np.isin(array_ids, self.image_ids)] = 1
  143. text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
  144. return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
  145. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  146. """
  147. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  148. Args:
  149. image_sizes (`list[list[int]]`, *optional*):
  150. The input sizes formatted as (height, width) per each image.
  151. Returns:
  152. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  153. input modalities, along with other useful data.
  154. """
  155. vision_data = {}
  156. if image_sizes is not None:
  157. # add 2 for BOI and EOI tokens
  158. num_image_tokens = [self.image_seq_length + 2] * len(image_sizes)
  159. num_image_patches = [1] * len(image_sizes)
  160. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  161. return MultiModalData(**vision_data)
  162. __all__ = ["ChameleonProcessor"]