processing_llama4.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239
  1. # coding=utf-8
  2. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Optional, Union
  16. from transformers.processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, Unpack
  17. from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
  18. from ...image_processing_utils import BatchFeature
  19. from ...image_utils import ImageInput, make_flat_list_of_images
class Llama4ImagesKwargs(ImagesKwargs, total=False):
    """Image-processor keyword arguments accepted by the Llama4 processor, extending `ImagesKwargs`."""

    # NOTE(review): presumably an upper bound on how many patches/tiles one image may be
    # split into -- the consumer is the image processor, which is not visible here; confirm.
    max_patches: Optional[int]
    # NOTE(review): presumably toggles resizing the image up to the largest supported
    # canvas before tiling -- confirm against the image processor.
    resize_to_max_canvas: Optional[bool]
class Llama4ProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema that `Llama4Processor.__call__` passes to `_merge_kwargs`."""

    # Image kwargs are narrowed to the Llama4-specific declaration.
    images_kwargs: Llama4ImagesKwargs
    # Per-group default values; only the tokenizer padding side is set here.
    # NOTE(review): presumably read by `ProcessorMixin._merge_kwargs` as fallbacks when the
    # caller does not supply a value -- confirm against `processing_utils`.
    _defaults = {
        "text_kwargs": {
            "padding_side": "left",
        },
    }
# Jinja chat template used as the default value of the `chat_template` argument of
# `Llama4Processor.__init__`. As written, the template:
#   * emits `bos_token`, then an optional system turn taken from the first message;
#   * renders tool definitions as JSON either inside the system turn or prepended to the
#     first user turn, depending on `tools_in_user_message` (defaults to true);
#   * wraps every turn as `<|header_start|>{role}<|header_end|>\n\n ... <|eot|>`, emitting
#     `<|image|>` for each content item of type "image";
#   * renders assistant tool calls after a `<|python_start|>...<|python_end|>` section and
#     tool results under the `ipython` role;
#   * appends an open assistant header when `add_generation_prompt` is set.
chat_template = "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %} \n {%- if messages[0]['content'] is string %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- else %}\n {#- FIXME: The processor requires an array, always. #}\n {%- set system_message = messages[0]['content'][0]['text']|trim %}\n {%- endif %}\n {%- set messages = messages[1:] %}\n {%- set user_supplied_system_message = true %}\n{%- else %}\n {%- set system_message = \"\" %}\n {%- set user_supplied_system_message = false %}\n{%- endif %}\n\n{#- System message if the user supplied one #}\n{%- if user_supplied_system_message %}\n {{- \"<|header_start|>system<|header_end|>\n\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\n\" }}\n {%- endif %}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|header_start|>user<|header_end|>\n\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\n\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\n\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\n\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' -}}\n {{- '<|python_start|>' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|python_end|>' }}\n {%- for tool_call in message.tool_calls %}\n {{- '{\"name\": \"' + tool_call.function.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.function.arguments | tojson }}\n {{- \"}\" }}\n {%- endfor %}\n {{- \"<|eot|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|header_start|>ipython<|header_end|>\n\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|header_start|>assistant<|header_end|>\n\n' }}\n{%- endif %}\n"
  31. class Llama4Processor(ProcessorMixin):
  32. r"""
  33. Constructs a Llama4 processor which wraps a [`AutoImageProcessor`] and
  34. [`PretrainedTokenizerFast`] tokenizer into a single processor that inherits both the image processor and
  35. tokenizer functionalities. See the [`~Llama4Processor.__call__`] and [`~Llama4Processor.decode`] for more information.
  36. Args:
  37. image_processor ([`AutoImageProcessor`], *optional*):
  38. The image processor is a required input.
  39. tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`], *optional*):
  40. The tokenizer is a required input.
  41. patch_size (`int`, *optional*, defaults to 28):
  42. The size of image patches for tokenization.
  43. img_size (`int`, *optional*, defaults to 364):
  44. The size of the image to be tokenized. This should correspond to the size given to the image processor.
  45. image_token (`str`, *optional*, defaults to `"<|image|>"`):
  46. The token to be used to represent an image in the text.
  47. downsample_factor (`int`, *optional*, defaults to 1):
  48. The factor by which to scale the patch size.
  49. start_of_img_token (`str`, *optional*, defaults to `"<|START_OF_IMG|>"`):
  50. The token to be used to represent the start of an image in the text.
  51. end_of_img_token (`str`, *optional*, defaults to `"<|END_OF_IMG|>"`):
  52. The token to be used to represent the end of an image in the text.
  53. img_patch_token (`str`, *optional*, defaults to `"<|IMG_PATCH|>"`):
  54. The token to be used to represent an image patch in the text.
  55. img_line_break_token (`str`, *optional*, defaults to `"<|IMG_LINE_BREAK|>"`):
  56. The token to be used to represent a line break in the text.
  57. tile_token (`str`, *optional*, defaults to `"TILE"`):
  58. The token to be used to represent an image patch in the text.
  59. tile_global_token (`str`, *optional*, defaults to `"TILE_GLOBAL"`):
  60. The token to be used to represent the cover image in the text.
  61. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
  62. in a chat into a tokenizable string.
  63. """
  64. attributes = ["image_processor", "tokenizer"]
  65. image_processor_class = "AutoImageProcessor"
  66. tokenizer_class = "AutoTokenizer"
  67. def __init__(
  68. self,
  69. image_processor=None,
  70. tokenizer=None,
  71. patch_size: int = 14,
  72. pixel_shuffle_ratio: float = 0.5,
  73. fake_image_token="<|image|>",
  74. image_token="<|image|>",
  75. start_of_image_token="<|image_start|>",
  76. end_of_image_token="<|image_end|>",
  77. patch_token="<|patch|>",
  78. tile_x_separator_token="<|tile_x_separator|>",
  79. tile_y_separator_token="<|tile_y_separator|>",
  80. chat_template=chat_template,
  81. **kwargs,
  82. ):
  83. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  84. self.downsample_ratio = int(round(1.0 / (pixel_shuffle_ratio**2)))
  85. self.patch_size = patch_size
  86. self.fake_image_token = fake_image_token
  87. self.image_token = image_token
  88. self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
  89. self.start_of_img_token = start_of_image_token
  90. self.end_of_img_token = end_of_image_token
  91. self.img_patch_token = patch_token
  92. self.tile_token = tile_x_separator_token
  93. self.tile_global_token = tile_y_separator_token
  94. def _prompt_split_image(self, aspect_ratio, num_patches_per_chunk):
  95. """
  96. Create a structured string representation of image tokens
  97. Args:
  98. num_patches: Number of patches in the image
  99. Returns:
  100. String with appropriate image tokens
  101. """
  102. img_string = "<|image_start|>"
  103. ratio_h, ratio_w = aspect_ratio
  104. if ratio_h * ratio_w > 1:
  105. for yy in range(ratio_h):
  106. for xx in range(ratio_w):
  107. img_string += "<|patch|>" * num_patches_per_chunk
  108. if xx < ratio_w - 1:
  109. img_string += "<|tile_x_separator|>"
  110. img_string += "<|tile_y_separator|>"
  111. img_string += "<|image|>"
  112. img_string += "<|patch|>" * num_patches_per_chunk
  113. img_string += "<|image_end|>"
  114. return img_string
  115. def __call__(
  116. self,
  117. images: Optional[ImageInput] = None,
  118. text: Optional[Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]]] = None,
  119. audio=None,
  120. videos=None,
  121. **kwargs: Unpack[Llama4ProcessorKwargs],
  122. ) -> BatchFeature:
  123. """
  124. Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
  125. and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] to encode the text.
  126. To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
  127. Llama4ImageProcessor's [`~Llama4ImageProcessor.__call__`] if `images` is not `None`.
  128. Args:
  129. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `list[PIL.Image.Image]`, `list[np.ndarray]`, `list[torch.Tensor]`):
  130. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  131. tensor. Both channels-first and channels-last formats are supported.
  132. text (`str`, `list[str]`, `list[list[str]]`):
  133. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  134. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  135. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  136. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  137. If set, will return tensors of a particular framework. Acceptable values are:
  138. - `'tf'`: Return TensorFlow `tf.constant` objects.
  139. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  140. - `'np'`: Return NumPy `np.ndarray` objects.
  141. - `'jax'`: Return JAX `jnp.ndarray` objects.
  142. Returns:
  143. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  144. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  145. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  146. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  147. `None`).
  148. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  149. """
  150. if text is None:
  151. raise ValueError("You have to specify text.")
  152. output_kwargs = self._merge_kwargs(
  153. Llama4ProcessorKwargs,
  154. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  155. **kwargs,
  156. )
  157. if not isinstance(text, (list, tuple)):
  158. text = [text]
  159. # Process images
  160. image_inputs = {}
  161. if images is not None:
  162. images = self.image_processor.fetch_images(images)
  163. images = make_flat_list_of_images(images)
  164. image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
  165. image_height, image_width = image_inputs["pixel_values"][0].shape[-2:]
  166. num_patches_per_chunk = int(
  167. (image_height // self.patch_size) * (image_width // self.patch_size) // self.downsample_ratio
  168. )
  169. aspect_ratios = image_inputs.pop("aspect_ratios")
  170. total_placeholders = sum(prompt.count(self.fake_image_token) for prompt in text)
  171. if total_placeholders != len(images):
  172. raise ValueError(
  173. f"Found {total_placeholders} placeholders across the batch, "
  174. f"but have {len(images)} flattened images."
  175. )
  176. image_index = 0
  177. processed_text = []
  178. for prompt in text:
  179. placeholder_count = prompt.count(self.fake_image_token)
  180. if placeholder_count == 0:
  181. # do nothing if there is no image
  182. processed_text.append(prompt)
  183. continue
  184. prompt_splits = prompt.split(self.fake_image_token)
  185. new_prompt = []
  186. for local_image_index, split_part in enumerate(prompt_splits):
  187. new_prompt.append(split_part)
  188. if local_image_index < placeholder_count:
  189. tokens_for_this_image = self._prompt_split_image(
  190. aspect_ratios[image_index], num_patches_per_chunk
  191. )
  192. image_index += 1
  193. new_prompt.append(tokens_for_this_image)
  194. processed_text.append("".join(new_prompt))
  195. if image_index != len(images):
  196. raise ValueError("Number of image placeholders in the prompt does not match the number of images.")
  197. text = processed_text
  198. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  199. text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
  200. self._check_special_mm_tokens(text, text_inputs, modalities=["image"])
  201. return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
# Public API of this module: only the processor class is exported.
__all__ = ["Llama4Processor"]