processing_glm4v.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glm4v/modular_glm4v.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glm4v.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from typing import Optional, Union
  22. import numpy as np
  23. from ...feature_extraction_utils import BatchFeature
  24. from ...image_utils import ImageInput
  25. from ...processing_utils import ImagesKwargs, MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs
  26. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  27. from ...utils import logging
  28. from ...video_utils import VideoInput
# Module-level logger; used by `Glm4vProcessor` to warn when a video's fps cannot be inferred.
logger = logging.get_logger(__name__)
class Glm4vVideosProcessorKwargs(VideosKwargs, total=False):
    """Video keyword arguments accepted by `Glm4vProcessor`, extending the generic `VideosKwargs`."""

    # A single frame rate, or (per the `list[float]` alternative) one value per video.
    # NOTE(review): how the video processor consumes a list is not visible here — confirm.
    fps: Union[list[float], float]
  32. class Glm4vImagesKwargs(ImagesKwargs):
  33. patch_size: Optional[int]
  34. temporal_patch_size: Optional[int]
  35. merge_size: Optional[int]
class Glm4vProcessorKwargs(ProcessingKwargs, total=False):
    """Per-modality keyword arguments accepted by `Glm4vProcessor.__call__`, together with processor defaults."""

    images_kwargs: Glm4vImagesKwargs
    # Default values per modality. Passed (via the class) to `self._merge_kwargs` in `Glm4vProcessor.__call__` and
    # read directly by `Glm4vProcessor._get_num_multimodal_tokens`.
    # NOTE(review): presumably these are the lowest-priority defaults in `_merge_kwargs` — confirm in ProcessorMixin.
    _defaults = {
        "text_kwargs": {
            # Do not pad unless the caller asks for it.
            "padding": False,
            "return_token_type_ids": False,
            # `__call__` pops this flag and only builds `mm_token_type_ids` when it is True.
            "return_mm_token_type_ids": False,
        },
        # `__call__` needs each video's metadata (fps/timestamps) to build the video prompt; it strips
        # `video_metadata` from the returned features unless the caller passed `return_metadata` explicitly.
        "videos_kwargs": {"return_metadata": True},
    }
    videos_kwargs: Glm4vVideosProcessorKwargs
  47. class Glm4vProcessor(ProcessorMixin):
  48. r"""
  49. Constructs a GLM-4V processor which wraps a GLM-4V image processor and a GLM-4 tokenizer into a single processor.
  50. [`~Glm4vProcessor.__call__`] and [`~Glm4vProcessor.decode`] for more information.
  51. Args:
  52. image_processor ([`Glm4vProcessor`], *optional*):
  53. The image processor is a required input.
  54. tokenizer ([`PreTrainedTokenizerFast`], *optional*):
  55. The tokenizer is a required input.
  56. video_processor ([`Glm4vVideoProcessor`], *optional*):
  57. The video processor is a required input.
  58. chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
  59. in a chat into a tokenizable string.
  60. """
  61. attributes = ["image_processor", "tokenizer", "video_processor"]
  62. image_processor_class = "AutoImageProcessor"
  63. video_processor_class = "AutoVideoProcessor"
  64. tokenizer_class = ("PreTrainedTokenizer", "PreTrainedTokenizerFast")
  65. def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
  66. super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
  67. self.image_token = "<|image|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
  68. self.video_token = "<|video|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
  69. self.image_token_id = (
  70. tokenizer.image_token_id
  71. if getattr(tokenizer, "image_token_id", None)
  72. else tokenizer.convert_tokens_to_ids(self.image_token)
  73. )
  74. self.video_token_id = (
  75. tokenizer.video_token_id
  76. if getattr(tokenizer, "video_token_id", None)
  77. else tokenizer.convert_tokens_to_ids(self.video_token)
  78. )
  79. def __call__(
  80. self,
  81. images: Optional[ImageInput] = None,
  82. text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
  83. videos: Optional[VideoInput] = None,
  84. **kwargs: Unpack[Glm4vProcessorKwargs],
  85. ) -> BatchFeature:
  86. """
  87. Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
  88. and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
  89. the text.
  90. Args:
  91. images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
  92. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  93. tensor. Both channels-first and channels-last formats are supported.
  94. text (`str`, `List[str]`, `List[List[str]]`):
  95. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  96. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  97. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  98. videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
  99. The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch
  100. tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported.
  101. return_tensors (`str` or [`~utils.TensorType`], *optional*):
  102. If set, will return tensors of a particular framework. Acceptable values are:
  103. - `'tf'`: Return TensorFlow `tf.constant` objects.
  104. - `'pt'`: Return PyTorch `torch.Tensor` objects.
  105. - `'np'`: Return NumPy `np.ndarray` objects.
  106. - `'jax'`: Return JAX `jnp.ndarray` objects.
  107. Returns:
  108. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  109. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  110. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  111. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  112. `None`).
  113. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  114. - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
  115. - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
  116. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
  117. """
  118. output_kwargs = self._merge_kwargs(
  119. Glm4vProcessorKwargs,
  120. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  121. **kwargs,
  122. )
  123. if images is not None:
  124. image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
  125. image_grid_thw = image_inputs["image_grid_thw"]
  126. else:
  127. image_inputs = {}
  128. image_grid_thw = None
  129. if videos is not None:
  130. videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
  131. # If user has not requested video metadata, pop it
  132. if "return_metadata" not in kwargs:
  133. video_metadata = videos_inputs.pop("video_metadata")
  134. else:
  135. video_metadata = videos_inputs["video_metadata"]
  136. video_grid_thw = videos_inputs["video_grid_thw"]
  137. else:
  138. videos_inputs = {}
  139. video_grid_thw = None
  140. if not isinstance(text, list):
  141. text = [text]
  142. text = text.copy() # below lines change text in-place
  143. if image_grid_thw is not None:
  144. merge_length = self.image_processor.merge_size**2
  145. index = 0
  146. for i in range(len(text)):
  147. while self.image_token in text[i]:
  148. num_image_tokens = image_grid_thw[index].prod() // merge_length
  149. text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
  150. index += 1
  151. text[i] = text[i].replace("<|placeholder|>", self.image_token)
  152. if video_grid_thw is not None:
  153. merge_length = self.video_processor.merge_size**2
  154. video_index = 0
  155. for i in range(len(text)):
  156. while self.video_token in text[i]:
  157. num_frames = video_grid_thw[video_index][0]
  158. video_structure = ""
  159. metadata = video_metadata[video_index]
  160. if metadata.fps is None:
  161. logger.warning_once(
  162. "SmolVLM requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
  163. "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
  164. "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
  165. )
  166. metadata.fps = 24 if metadata.fps is None else metadata.fps
  167. timestamps = metadata.timestamps[::2] # mrope
  168. unique_timestamps = []
  169. for idx in range(0, len(timestamps)):
  170. unique_timestamps.append(timestamps[idx])
  171. selected_timestamps = unique_timestamps[:num_frames]
  172. while len(selected_timestamps) < num_frames:
  173. selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
  174. for frame_idx in range(num_frames):
  175. timestamp_sec = selected_timestamps[frame_idx]
  176. frame_structure = f"<|begin_of_image|>{self.image_token}<|end_of_image|>{int(timestamp_sec)}"
  177. video_structure += frame_structure
  178. text[i] = text[i].replace(self.video_token, video_structure, 1)
  179. num_image_tokens = (
  180. video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
  181. )
  182. for frame_idx in range(num_frames):
  183. if self.image_token in text[i]:
  184. text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
  185. video_index += 1
  186. text[i] = text[i].replace("<|placeholder|>", self.image_token)
  187. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  188. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  189. text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
  190. self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
  191. if return_mm_token_type_ids:
  192. array_ids = np.array(text_inputs["input_ids"])
  193. mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
  194. mm_token_type_ids[array_ids == self.image_token_id] = 1
  195. text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
  196. return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
  197. def _get_num_multimodal_tokens(self, image_sizes=None, video_sizes=None, **kwargs):
  198. """
  199. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  200. Args:
  201. image_sizes (`list[list[int]]`, *optional*):
  202. The input sizes formatted as (height, width) per each image.
  203. video_sizes (`list[list[int]]`, *optional*):
  204. The input sizes formatted as (num_frames, height, width) per each video.
  205. Returns:
  206. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  207. input modalities, along with other useful data.
  208. """
  209. vision_data = {}
  210. if image_sizes is not None:
  211. images_kwargs = Glm4vProcessorKwargs._defaults.get("images_kwargs", {})
  212. images_kwargs.update(kwargs)
  213. merge_size = images_kwargs.get("merge_size", None) or self.image_processor.merge_size
  214. num_image_patches = [
  215. self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
  216. for image_size in image_sizes
  217. ]
  218. num_image_tokens = [(num_patches // merge_size**2) for num_patches in num_image_patches]
  219. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  220. if video_sizes is not None:
  221. videos_kwargs = Glm4vProcessorKwargs._defaults.get("videos_kwargs", {})
  222. videos_kwargs.update(kwargs)
  223. num_video_patches = [
  224. self.video_processor.get_number_of_video_patches(*video_size, videos_kwargs)
  225. for video_size in video_sizes
  226. ]
  227. num_video_tokens = [(num_patches // merge_size**2) for num_patches in num_video_patches]
  228. vision_data["num_video_tokens"] = num_video_tokens
  229. return MultiModalData(**vision_data)
  230. def post_process_image_text_to_text(
  231. self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs
  232. ):
  233. """
  234. Post-process the output of the model to decode the text.
  235. Args:
  236. generated_outputs (`torch.Tensor` or `np.ndarray`):
  237. The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)`
  238. or `(sequence_length,)`.
  239. skip_special_tokens (`bool`, *optional*, defaults to `True`):
  240. Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method.
  241. clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
  242. Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method.
  243. **kwargs:
  244. Additional arguments to be passed to the tokenizer's `batch_decode method`.
  245. Returns:
  246. `list[str]`: The decoded text.
  247. """
  248. return self.tokenizer.batch_decode(
  249. generated_outputs,
  250. skip_special_tokens=skip_special_tokens,
  251. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  252. **kwargs,
  253. )
# Public API exported by this module.
__all__ = ["Glm4vProcessor"]