modeling_vipllava.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/vipllava/modular_vipllava.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_vipllava.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. from dataclasses import dataclass
  22. from typing import Optional, Union
  23. import torch
  24. from torch import nn
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache
  27. from ...generation import GenerationMixin
  28. from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
  29. from ...modeling_utils import PreTrainedModel
  30. from ...utils import auto_docstring, can_return_tuple
  31. from ..auto import AutoModel
  32. from .configuration_vipllava import VipLlavaConfig
  33. @dataclass
  34. @auto_docstring(
  35. custom_intro="""
  36. Base class for VipLlava outputs, with hidden states and attentions.
  37. """
  38. )
  39. class VipLlavaModelOutputWithPast(BaseModelOutputWithPast):
  40. r"""
  41. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  42. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  43. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  44. `past_key_values` input) to speed up sequential decoding.
  45. image_hidden_states (`torch.FloatTensor`, *optional*):
  46. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  47. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  48. """
  49. image_hidden_states: Optional[torch.FloatTensor] = None
  50. @dataclass
  51. @auto_docstring(
  52. custom_intro="""
  53. Base class for VipLlava causal language model (or autoregressive) outputs.
  54. """
  55. )
  56. class VipLlavaCausalLMOutputWithPast(ModelOutput):
  57. r"""
  58. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  59. Language modeling loss (for next-token prediction).
  60. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  61. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  62. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  63. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  64. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  65. `past_key_values` input) to speed up sequential decoding.
  66. image_hidden_states (`torch.FloatTensor`, *optional*):
  67. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  68. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  69. """
  70. loss: Optional[torch.FloatTensor] = None
  71. logits: Optional[torch.FloatTensor] = None
  72. past_key_values: Optional[Cache] = None
  73. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  74. attentions: Optional[tuple[torch.FloatTensor]] = None
  75. image_hidden_states: Optional[torch.FloatTensor] = None
  76. class VipLlavaMultiModalProjector(nn.Module):
  77. def __init__(self, config: VipLlavaConfig):
  78. super().__init__()
  79. num_feature_layers = 1 if isinstance(config.vision_feature_layers, int) else len(config.vision_feature_layers)
  80. self.projector_layernorm = nn.LayerNorm(
  81. num_feature_layers * config.vision_config.hidden_size, eps=config.projector_layernorm_eps
  82. )
  83. self.linear_1 = nn.Linear(
  84. num_feature_layers * config.vision_config.hidden_size,
  85. config.text_config.hidden_size,
  86. bias=True,
  87. )
  88. self.act = ACT2FN[config.projector_hidden_act]
  89. self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
  90. def forward(self, hidden_states):
  91. hidden_states = self.projector_layernorm(hidden_states)
  92. hidden_states = self.linear_1(hidden_states)
  93. hidden_states = self.act(hidden_states)
  94. hidden_states = self.linear_2(hidden_states)
  95. return hidden_states
  96. @auto_docstring
  97. class VipLlavaPreTrainedModel(PreTrainedModel):
  98. config: VipLlavaConfig
  99. base_model_prefix = ""
  100. supports_gradient_checkpointing = True
  101. _skip_keys_device_placement = "past_key_values"
  102. _supports_flash_attn = True
  103. _supports_sdpa = True
  104. _can_compile_fullgraph = True
  105. _supports_flex_attn = True
  106. _supports_attention_backend = True
  107. @auto_docstring(
  108. custom_intro="""
  109. The VipLlava model which consists of a vision backbone and a language model, without a language modeling head.
  110. """
  111. )
  112. class VipLlavaModel(VipLlavaPreTrainedModel):
  113. _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
  114. def __init__(self, config: VipLlavaConfig):
  115. super().__init__(config)
  116. self.vision_tower = AutoModel.from_config(config.vision_config)
  117. self.multi_modal_projector = VipLlavaMultiModalProjector(config)
  118. self.language_model = AutoModel.from_config(config.text_config)
  119. self.post_init()
  120. def get_input_embeddings(self):
  121. return self.language_model.get_input_embeddings()
  122. def set_input_embeddings(self, value):
  123. self.language_model.set_input_embeddings(value)
  124. def set_decoder(self, decoder):
  125. self.language_model = decoder
  126. def get_decoder(self):
  127. return self.language_model
  128. def get_image_features(
  129. self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, list[int]]] = None
  130. ):
  131. """
  132. Obtains image last hidden states from the vision tower and apply multimodal projection.
  133. Args:
  134. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
  135. The tensors corresponding to the input images.
  136. vision_feature_layers (`Union[int, list[int]]`):
  137. The vision feature layer, or the list of indexes of the layers to select
  138. the vision feature.
  139. Returns:
  140. image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
  141. """
  142. vision_feature_layers = (
  143. vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
  144. )
  145. image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
  146. # If multiple feature layers are provided (which is usually the case)
  147. # then the image features are concatenated after the CLS is removed.
  148. if isinstance(vision_feature_layers, int):
  149. image_features = image_outputs.hidden_states[vision_feature_layers][:, 1:]
  150. else:
  151. # Usually, we select the features from index 1: the layers -2, -5, -8, -11 and 6
  152. image_features = [image_outputs.hidden_states[index][:, 1:] for index in vision_feature_layers]
  153. image_features = torch.cat(image_features, dim=-1)
  154. image_features = self.multi_modal_projector(image_features)
  155. return image_features
  156. def get_placeholder_mask(
  157. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  158. ):
  159. """
  160. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  161. equal to the length of multimodal features. If the lengths are different, an error is raised.
  162. """
  163. if input_ids is None:
  164. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  165. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  166. )
  167. special_image_mask = special_image_mask.all(-1)
  168. else:
  169. special_image_mask = input_ids == self.config.image_token_id
  170. n_image_tokens = special_image_mask.sum()
  171. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  172. n_image_features = image_features.shape[0] * image_features.shape[1]
  173. if inputs_embeds[special_image_mask].numel() != image_features.numel():
  174. raise ValueError(
  175. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  176. )
  177. return special_image_mask
  178. @auto_docstring
  179. def forward(
  180. self,
  181. input_ids: Optional[torch.LongTensor] = None,
  182. pixel_values: Optional[torch.FloatTensor] = None,
  183. attention_mask: Optional[torch.Tensor] = None,
  184. position_ids: Optional[torch.LongTensor] = None,
  185. past_key_values: Optional[Cache] = None,
  186. inputs_embeds: Optional[torch.FloatTensor] = None,
  187. vision_feature_layers: Optional[Union[int, list[int]]] = None,
  188. use_cache: Optional[bool] = None,
  189. output_attentions: Optional[bool] = None,
  190. output_hidden_states: Optional[bool] = None,
  191. return_dict: Optional[bool] = None,
  192. cache_position: Optional[torch.LongTensor] = None,
  193. **lm_kwargs,
  194. ) -> Union[tuple, VipLlavaModelOutputWithPast]:
  195. r"""
  196. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  197. The vision feature layer, or the list of indexes of the layers to select
  198. the vision feature.
  199. """
  200. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  201. output_hidden_states = (
  202. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  203. )
  204. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  205. vision_feature_layers = (
  206. vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
  207. )
  208. if (input_ids is None) ^ (inputs_embeds is not None):
  209. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  210. if inputs_embeds is None:
  211. inputs_embeds = self.get_input_embeddings()(input_ids)
  212. if pixel_values is not None:
  213. image_features = self.get_image_features(
  214. pixel_values=pixel_values, vision_feature_layers=vision_feature_layers
  215. )
  216. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  217. special_image_mask = self.get_placeholder_mask(
  218. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  219. )
  220. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  221. outputs = self.language_model(
  222. attention_mask=attention_mask,
  223. position_ids=position_ids,
  224. past_key_values=past_key_values,
  225. inputs_embeds=inputs_embeds,
  226. use_cache=use_cache,
  227. output_attentions=output_attentions,
  228. output_hidden_states=output_hidden_states,
  229. return_dict=True,
  230. cache_position=cache_position,
  231. **lm_kwargs,
  232. )
  233. output = VipLlavaModelOutputWithPast(
  234. last_hidden_state=outputs.last_hidden_state,
  235. past_key_values=outputs.past_key_values,
  236. hidden_states=outputs.hidden_states,
  237. attentions=outputs.attentions,
  238. image_hidden_states=image_features if pixel_values is not None else None,
  239. )
  240. return output if return_dict else output.to_tuple()
  241. @auto_docstring(
  242. custom_intro="""
  243. The VIPLLAVA model which consists of a vision backbone and a language model.
  244. """
  245. )
  246. class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin):
  247. _checkpoint_conversion_mapping = {
  248. "^language_model.model": "model.language_model",
  249. "^vision_tower": "model.vision_tower",
  250. "^multi_modal_projector": "model.multi_modal_projector",
  251. "^language_model.lm_head": "lm_head",
  252. }
  253. _tied_weights_keys = ["lm_head.weight"]
  254. def __init__(self, config: VipLlavaConfig):
  255. super().__init__(config)
  256. self.model = VipLlavaModel(config)
  257. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  258. self.post_init()
  259. def get_input_embeddings(self):
  260. return self.model.get_input_embeddings()
  261. def set_input_embeddings(self, value):
  262. self.model.set_input_embeddings(value)
  263. def get_output_embeddings(self) -> nn.Module:
  264. return self.lm_head
  265. def set_decoder(self, decoder):
  266. self.model.set_decoder(decoder)
  267. def get_decoder(self):
  268. return self.model.get_decoder()
  269. def get_image_features(
  270. self, pixel_values: torch.FloatTensor, vision_feature_layers: Optional[Union[int, list[int]]] = None
  271. ):
  272. return self.model.get_image_features(pixel_values=pixel_values, vision_feature_layers=vision_feature_layers)
  273. # Make modules available through conditional class for BC
  274. @property
  275. def language_model(self):
  276. return self.model.language_model
  277. @property
  278. def vision_tower(self):
  279. return self.model.vision_tower
  280. @property
  281. def multi_modal_projector(self):
  282. return self.model.multi_modal_projector
  283. @can_return_tuple
  284. @auto_docstring
  285. def forward(
  286. self,
  287. input_ids: Optional[torch.LongTensor] = None,
  288. pixel_values: Optional[torch.FloatTensor] = None,
  289. attention_mask: Optional[torch.Tensor] = None,
  290. position_ids: Optional[torch.LongTensor] = None,
  291. past_key_values: Optional[Cache] = None,
  292. inputs_embeds: Optional[torch.FloatTensor] = None,
  293. vision_feature_layers: Optional[Union[int, list[int]]] = None,
  294. labels: Optional[torch.LongTensor] = None,
  295. use_cache: Optional[bool] = None,
  296. output_attentions: Optional[bool] = None,
  297. output_hidden_states: Optional[bool] = None,
  298. return_dict: Optional[bool] = None,
  299. cache_position: Optional[torch.LongTensor] = None,
  300. logits_to_keep: Union[int, torch.Tensor] = 0,
  301. **lm_kwargs,
  302. ) -> Union[tuple, VipLlavaCausalLMOutputWithPast]:
  303. r"""
  304. vision_feature_layers (`Union[int, list[int]]`, *optional*):
  305. The vision feature layer, or the list of indexes of the layers to select
  306. the vision feature.
  307. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  308. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  309. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  310. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  311. Example:
  312. ```python
  313. >>> import torch
  314. >>> from PIL import Image
  315. >>> import requests
  316. >>> from transformers import AutoProcessor, VipLlavaForConditionalGeneration
  317. >>> model = VipLlavaForConditionalGeneration.from_pretrained("llava-hf/vip-llava-7b-hf", device_map="auto", dtype=torch.float16)
  318. >>> processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
  319. >>> prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n{}###Assistant:"
  320. >>> question = "Can you please describe this image?"
  321. >>> prompt = prompt.format(question)
  322. >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/compel-neg.png"
  323. >>> image = Image.open(requests.get(url, stream=True).raw)
  324. >>> inputs = processor(text=text, images=image, return_tensors="pt").to(0, torch.float16)
  325. >>> # Generate
  326. >>> generate_ids = model.generate(**inputs, max_new_tokens=20)
  327. >>> processor.decode(generate_ids[0][len(inputs["input_ids"][0]):], skip_special_tokens=True)
  328. The image features a brown and white cat sitting on a green surface, with a red ball in its
  329. ```"""
  330. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  331. output_hidden_states = (
  332. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  333. )
  334. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  335. vision_feature_layers = (
  336. vision_feature_layers if vision_feature_layers is not None else self.config.vision_feature_layers
  337. )
  338. outputs = self.model(
  339. input_ids=input_ids,
  340. pixel_values=pixel_values,
  341. attention_mask=attention_mask,
  342. position_ids=position_ids,
  343. past_key_values=past_key_values,
  344. inputs_embeds=inputs_embeds,
  345. use_cache=use_cache,
  346. vision_feature_layers=vision_feature_layers,
  347. output_attentions=output_attentions,
  348. output_hidden_states=output_hidden_states,
  349. return_dict=True,
  350. cache_position=cache_position,
  351. **lm_kwargs,
  352. )
  353. hidden_states = outputs[0]
  354. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  355. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  356. logits = self.lm_head(hidden_states[:, slice_indices, :])
  357. loss = None
  358. if labels is not None:
  359. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
  360. return VipLlavaCausalLMOutputWithPast(
  361. loss=loss,
  362. logits=logits,
  363. past_key_values=outputs.past_key_values,
  364. hidden_states=outputs.hidden_states,
  365. attentions=outputs.attentions,
  366. image_hidden_states=outputs.image_hidden_states,
  367. )
  368. def prepare_inputs_for_generation(
  369. self,
  370. input_ids,
  371. past_key_values=None,
  372. inputs_embeds=None,
  373. pixel_values=None,
  374. attention_mask=None,
  375. cache_position=None,
  376. logits_to_keep=None,
  377. **kwargs,
  378. ):
  379. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  380. model_inputs = super().prepare_inputs_for_generation(
  381. input_ids,
  382. past_key_values=past_key_values,
  383. inputs_embeds=inputs_embeds,
  384. attention_mask=attention_mask,
  385. cache_position=cache_position,
  386. logits_to_keep=logits_to_keep,
  387. **kwargs,
  388. )
  389. if cache_position[0] == 0:
  390. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  391. # Otherwise we need pixel values to be passed to model
  392. model_inputs["pixel_values"] = pixel_values
  393. return model_inputs
  394. __all__ = ["VipLlavaModel", "VipLlavaForConditionalGeneration", "VipLlavaPreTrainedModel"]