modeling_llava.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. # coding=utf-8
  2. # Copyright 2023 the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Llava model."""
  16. from dataclasses import dataclass
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache
  22. from ...generation import GenerationMixin
  23. from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
  24. from ...modeling_utils import PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  27. from ..auto import AutoModel
  28. from .configuration_llava import LlavaConfig
  29. logger = logging.get_logger(__name__)
  30. @dataclass
  31. @auto_docstring(
  32. custom_intro="""
  33. Base class for Llava outputs, with hidden states and attentions.
  34. """
  35. )
  36. class LlavaModelOutputWithPast(BaseModelOutputWithPast):
  37. r"""
  38. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  39. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  40. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  41. `past_key_values` input) to speed up sequential decoding.
  42. image_hidden_states (`torch.FloatTensor`, *optional*):
  43. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  44. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  45. """
  46. image_hidden_states: Optional[torch.FloatTensor] = None
  47. @dataclass
  48. @auto_docstring(
  49. custom_intro="""
  50. Base class for Llava causal language model (or autoregressive) outputs.
  51. """
  52. )
  53. class LlavaCausalLMOutputWithPast(ModelOutput):
  54. r"""
  55. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  56. Language modeling loss (for next-token prediction).
  57. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  58. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  59. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  60. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  61. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  62. `past_key_values` input) to speed up sequential decoding.
  63. image_hidden_states (`torch.FloatTensor`, *optional*):
  64. A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
  65. image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
  66. """
  67. loss: Optional[torch.FloatTensor] = None
  68. logits: Optional[torch.FloatTensor] = None
  69. past_key_values: Optional[Cache] = None
  70. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  71. attentions: Optional[tuple[torch.FloatTensor]] = None
  72. image_hidden_states: Optional[torch.FloatTensor] = None
  73. class LlavaMultiModalProjector(nn.Module):
  74. def __init__(self, config: LlavaConfig):
  75. super().__init__()
  76. # We have hidden_size * the number of vision feature layers
  77. num_feature_layers = 1 if isinstance(config.vision_feature_layer, int) else len(config.vision_feature_layer)
  78. self.linear_1 = nn.Linear(
  79. config.vision_config.hidden_size * num_feature_layers,
  80. config.text_config.hidden_size,
  81. bias=config.multimodal_projector_bias,
  82. )
  83. self.act = ACT2FN[config.projector_hidden_act]
  84. self.linear_2 = nn.Linear(
  85. config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
  86. )
  87. def forward(self, image_features):
  88. hidden_states = self.linear_1(image_features)
  89. hidden_states = self.act(hidden_states)
  90. hidden_states = self.linear_2(hidden_states)
  91. return hidden_states
  92. @auto_docstring
  93. class LlavaPreTrainedModel(PreTrainedModel):
  94. config: LlavaConfig
  95. base_model_prefix = ""
  96. supports_gradient_checkpointing = True
  97. _skip_keys_device_placement = "past_key_values"
  98. _supports_flash_attn = True
  99. _supports_sdpa = True
  100. _can_compile_fullgraph = True
  101. _supports_flex_attn = True
  102. _supports_attention_backend = True
  103. @auto_docstring(
  104. custom_intro="""
  105. The Llava model which consists of a vision backbone and a language model, without a language modeling head.
  106. """
  107. )
  108. class LlavaModel(LlavaPreTrainedModel):
  109. _checkpoint_conversion_mapping = {"language_model.model": "language_model"}
  110. def __init__(self, config: LlavaConfig):
  111. super().__init__(config)
  112. self.vision_tower = AutoModel.from_config(config.vision_config)
  113. self.multi_modal_projector = LlavaMultiModalProjector(config)
  114. self.language_model = AutoModel.from_config(config.text_config)
  115. self.post_init()
  116. def get_input_embeddings(self):
  117. return self.language_model.get_input_embeddings()
  118. def set_input_embeddings(self, value):
  119. self.language_model.set_input_embeddings(value)
  120. def set_decoder(self, decoder):
  121. self.language_model = decoder
  122. def get_decoder(self):
  123. return self.language_model
  124. def get_image_features(
  125. self,
  126. pixel_values: torch.FloatTensor,
  127. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  128. vision_feature_select_strategy: Optional[str] = None,
  129. **kwargs,
  130. ):
  131. """
  132. Obtains image last hidden states from the vision tower and apply multimodal projection.
  133. Args:
  134. pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
  135. The tensors corresponding to the input images.
  136. vision_feature_layer (`Union[int, list[int]]`, *optional*):
  137. The index of the layer to select the vision feature. If multiple indices are provided,
  138. the vision feature of the corresponding indices will be concatenated to form the
  139. vision features.
  140. vision_feature_select_strategy (`str`, *optional*):
  141. The feature selection strategy used to select the vision feature from the vision backbone.
  142. Can be one of `"default"` or `"full"`
  143. Returns:
  144. image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
  145. """
  146. vision_feature_layer = (
  147. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  148. )
  149. vision_feature_select_strategy = (
  150. vision_feature_select_strategy
  151. if vision_feature_select_strategy is not None
  152. else self.config.vision_feature_select_strategy
  153. )
  154. if vision_feature_select_strategy not in ["default", "full"]:
  155. raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
  156. kwargs = {k: v for k, v in kwargs.items() if v is not None}
  157. # this is not memory efficient at all (output_hidden_states=True) will save all the hidden states.
  158. image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
  159. # If we have one vision feature layer, return the corresponding hidden states,
  160. # otherwise, select the hidden states of each feature layer and concatenate them
  161. if isinstance(vision_feature_layer, int):
  162. selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
  163. if vision_feature_select_strategy == "default":
  164. selected_image_feature = selected_image_feature[:, 1:]
  165. else:
  166. hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
  167. # For default; crop CLS from each hidden state in the hidden state pool
  168. if vision_feature_select_strategy == "default":
  169. hs_pool = [hs[:, 1:] for hs in hs_pool]
  170. selected_image_feature = torch.cat(hs_pool, dim=-1)
  171. image_features = self.multi_modal_projector(selected_image_feature)
  172. if "image_sizes" in kwargs:
  173. split_sizes = [
  174. (height // self.vision_tower.patch_size) * (width // self.vision_tower.patch_size)
  175. for height, width in kwargs["image_sizes"]
  176. ]
  177. image_features = torch.split(image_features.squeeze(0), split_sizes)
  178. else:
  179. image_features = list(image_features)
  180. return image_features
  181. def get_placeholder_mask(
  182. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  183. ):
  184. """
  185. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  186. equal to the length of multimodal features. If the lengths are different, an error is raised.
  187. """
  188. if input_ids is None:
  189. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  190. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  191. )
  192. special_image_mask = special_image_mask.all(-1)
  193. else:
  194. special_image_mask = input_ids == self.config.image_token_id
  195. n_image_tokens = special_image_mask.sum()
  196. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  197. n_image_features = image_features.shape[0] * image_features.shape[1]
  198. if inputs_embeds[special_image_mask].numel() != image_features.numel():
  199. raise ValueError(
  200. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  201. )
  202. return special_image_mask
  203. @can_return_tuple
  204. @auto_docstring
  205. def forward(
  206. self,
  207. input_ids: Optional[torch.LongTensor] = None,
  208. pixel_values: Optional[torch.FloatTensor] = None,
  209. attention_mask: Optional[torch.Tensor] = None,
  210. position_ids: Optional[torch.LongTensor] = None,
  211. past_key_values: Optional[Cache] = None,
  212. inputs_embeds: Optional[torch.FloatTensor] = None,
  213. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  214. vision_feature_select_strategy: Optional[str] = None,
  215. cache_position: Optional[torch.LongTensor] = None,
  216. image_sizes: Optional[torch.Tensor] = None,
  217. **kwargs: Unpack[TransformersKwargs],
  218. ) -> Union[tuple, LlavaModelOutputWithPast]:
  219. vision_feature_layer = (
  220. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  221. )
  222. vision_feature_select_strategy = (
  223. vision_feature_select_strategy
  224. if vision_feature_select_strategy is not None
  225. else self.config.vision_feature_select_strategy
  226. )
  227. if (input_ids is None) ^ (inputs_embeds is not None):
  228. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  229. if inputs_embeds is None:
  230. inputs_embeds = self.get_input_embeddings()(input_ids)
  231. if pixel_values is not None:
  232. image_features = self.get_image_features(
  233. pixel_values=pixel_values,
  234. vision_feature_layer=vision_feature_layer,
  235. vision_feature_select_strategy=vision_feature_select_strategy,
  236. image_sizes=image_sizes,
  237. )
  238. image_features = torch.cat(image_features, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  239. special_image_mask = self.get_placeholder_mask(
  240. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  241. )
  242. inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
  243. outputs = self.language_model(
  244. attention_mask=attention_mask,
  245. position_ids=position_ids,
  246. past_key_values=past_key_values,
  247. inputs_embeds=inputs_embeds,
  248. cache_position=cache_position,
  249. **kwargs,
  250. )
  251. return LlavaModelOutputWithPast(
  252. last_hidden_state=outputs.last_hidden_state,
  253. past_key_values=outputs.past_key_values,
  254. hidden_states=outputs.hidden_states,
  255. attentions=outputs.attentions,
  256. image_hidden_states=image_features if pixel_values is not None else None,
  257. )
  258. @auto_docstring(
  259. custom_intro="""
  260. The LLAVA model which consists of a vision backbone and a language model.
  261. """
  262. )
  263. class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin):
  264. _checkpoint_conversion_mapping = {
  265. "^language_model.model": "model.language_model",
  266. "^vision_tower": "model.vision_tower",
  267. "^multi_modal_projector": "model.multi_modal_projector",
  268. "^language_model.lm_head": "lm_head",
  269. }
  270. _tied_weights_keys = ["lm_head.weight"]
  271. def __init__(self, config: LlavaConfig):
  272. super().__init__(config)
  273. self.model = LlavaModel(config)
  274. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  275. self.post_init()
  276. def get_input_embeddings(self):
  277. return self.model.get_input_embeddings()
  278. def set_input_embeddings(self, value):
  279. self.model.set_input_embeddings(value)
  280. def get_output_embeddings(self) -> nn.Module:
  281. return self.lm_head
  282. def set_decoder(self, decoder):
  283. self.model.set_decoder(decoder)
  284. def get_decoder(self):
  285. return self.model.get_decoder()
  286. def get_image_features(
  287. self,
  288. pixel_values: torch.FloatTensor,
  289. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  290. vision_feature_select_strategy: Optional[str] = None,
  291. **kwargs,
  292. ):
  293. return self.model.get_image_features(
  294. pixel_values=pixel_values,
  295. vision_feature_layer=vision_feature_layer,
  296. vision_feature_select_strategy=vision_feature_select_strategy,
  297. **kwargs,
  298. )
  299. # Make modules available through conditional class for BC
  300. @property
  301. def language_model(self):
  302. return self.model.language_model
  303. @property
  304. def vision_tower(self):
  305. return self.model.vision_tower
  306. @property
  307. def multi_modal_projector(self):
  308. return self.model.multi_modal_projector
  309. @can_return_tuple
  310. @auto_docstring
  311. def forward(
  312. self,
  313. input_ids: Optional[torch.LongTensor] = None,
  314. pixel_values: Optional[torch.FloatTensor] = None,
  315. attention_mask: Optional[torch.Tensor] = None,
  316. position_ids: Optional[torch.LongTensor] = None,
  317. past_key_values: Optional[Cache] = None,
  318. inputs_embeds: Optional[torch.FloatTensor] = None,
  319. vision_feature_layer: Optional[Union[int, list[int]]] = None,
  320. vision_feature_select_strategy: Optional[str] = None,
  321. labels: Optional[torch.LongTensor] = None,
  322. cache_position: Optional[torch.LongTensor] = None,
  323. logits_to_keep: Union[int, torch.Tensor] = 0,
  324. image_sizes: Optional[torch.Tensor] = None,
  325. **kwargs: Unpack[TransformersKwargs],
  326. ) -> Union[tuple, LlavaCausalLMOutputWithPast]:
  327. r"""
  328. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  329. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  330. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  331. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  332. Example:
  333. ```python
  334. >>> from PIL import Image
  335. >>> import requests
  336. >>> from transformers import AutoProcessor, LlavaForConditionalGeneration
  337. >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
  338. >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
  339. >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
  340. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  341. >>> image = Image.open(requests.get(url, stream=True).raw)
  342. >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
  343. >>> # Generate
  344. >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
  345. >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  346. "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
  347. ```"""
  348. vision_feature_layer = (
  349. vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
  350. )
  351. vision_feature_select_strategy = (
  352. vision_feature_select_strategy
  353. if vision_feature_select_strategy is not None
  354. else self.config.vision_feature_select_strategy
  355. )
  356. outputs = self.model(
  357. input_ids=input_ids,
  358. pixel_values=pixel_values,
  359. attention_mask=attention_mask,
  360. position_ids=position_ids,
  361. past_key_values=past_key_values,
  362. inputs_embeds=inputs_embeds,
  363. vision_feature_layer=vision_feature_layer,
  364. vision_feature_select_strategy=vision_feature_select_strategy,
  365. cache_position=cache_position,
  366. image_sizes=image_sizes,
  367. **kwargs,
  368. )
  369. hidden_states = outputs[0]
  370. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  371. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  372. logits = self.lm_head(hidden_states[:, slice_indices, :])
  373. loss = None
  374. if labels is not None:
  375. loss = self.loss_function(
  376. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  377. )
  378. return LlavaCausalLMOutputWithPast(
  379. loss=loss,
  380. logits=logits,
  381. past_key_values=outputs.past_key_values,
  382. hidden_states=outputs.hidden_states,
  383. attentions=outputs.attentions,
  384. image_hidden_states=outputs.image_hidden_states,
  385. )
  386. def prepare_inputs_for_generation(
  387. self,
  388. input_ids,
  389. past_key_values=None,
  390. inputs_embeds=None,
  391. pixel_values=None,
  392. attention_mask=None,
  393. cache_position=None,
  394. logits_to_keep=None,
  395. **kwargs,
  396. ):
  397. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  398. model_inputs = super().prepare_inputs_for_generation(
  399. input_ids,
  400. past_key_values=past_key_values,
  401. inputs_embeds=inputs_embeds,
  402. attention_mask=attention_mask,
  403. cache_position=cache_position,
  404. logits_to_keep=logits_to_keep,
  405. **kwargs,
  406. )
  407. if cache_position[0] == 0:
  408. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  409. # Otherwise we need pixel values to be passed to model
  410. model_inputs["pixel_values"] = pixel_values
  411. return model_inputs
  412. __all__ = ["LlavaForConditionalGeneration", "LlavaPreTrainedModel", "LlavaModel"]