modeling_kosmos2.py 80 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846
  1. # coding=utf-8
  2. # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch KOSMOS-2 model."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Any, Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  25. from ...modeling_layers import GradientCheckpointingLayer
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. BaseModelOutputWithPooling,
  30. CausalLMOutputWithCrossAttentions,
  31. )
  32. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  33. from ...processing_utils import Unpack
  34. from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  35. from ...utils.deprecation import deprecate_kwarg
  36. from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig
  37. logger = logging.get_logger(__name__)
  38. def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
  39. """
  40. Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
  41. """
  42. bsz, src_len = mask.size()
  43. tgt_len = tgt_len if tgt_len is not None else src_len
  44. expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
  45. inverted_mask = 1.0 - expanded_mask
  46. return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
  47. def _make_causal_mask(
  48. input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
  49. ):
  50. """
  51. Make causal mask used for bi-directional self-attention.
  52. """
  53. bsz, tgt_len = input_ids_shape
  54. mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
  55. mask_cond = torch.arange(mask.size(-1), device=device)
  56. mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
  57. mask = mask.to(dtype)
  58. if past_key_values_length > 0:
  59. mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
  60. return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
  61. # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
  62. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  63. """
  64. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  65. are ignored. This is modified from fairseq's `utils.make_positions`.
  66. Args:
  67. x: torch.Tensor x:
  68. Returns: torch.Tensor
  69. """
  70. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  71. mask = input_ids.ne(padding_idx).int()
  72. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  73. return incremental_indices.long() + padding_idx
  74. @dataclass
  75. @auto_docstring(
  76. custom_intro="""
  77. Base class for text model's outputs that also contains a pooling of the last hidden states.
  78. """
  79. )
  80. class Kosmos2ModelOutput(ModelOutput):
  81. r"""
  82. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  83. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  84. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
  85. `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
  86. input) to speed up sequential decoding.
  87. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  88. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  89. projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
  90. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  91. sequence_length)`.
  92. Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
  93. the weighted average in the self-attention heads.
  94. vision_model_output (`BaseModelOutputWithPooling`, *optional*):
  95. The output of the [`Kosmos2VisionModel`].
  96. """
  97. last_hidden_state: Optional[torch.FloatTensor] = None
  98. past_key_values: Optional[Cache] = None
  99. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  100. attentions: Optional[tuple[torch.FloatTensor]] = None
  101. image_embeds: Optional[torch.FloatTensor] = None
  102. projection_attentions: Optional[tuple[torch.FloatTensor]] = None
  103. vision_model_output: BaseModelOutputWithPooling = None
  104. def to_tuple(self) -> tuple[Any]:
  105. return tuple(
  106. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  107. for k in self.keys()
  108. )
  109. @dataclass
  110. @auto_docstring(
  111. custom_intro="""
  112. Model output class for `Kosmos2ForConditionalGeneration`.
  113. """
  114. )
  115. class Kosmos2ForConditionalGenerationModelOutput(ModelOutput):
  116. r"""
  117. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  118. Language modeling loss (for next-token prediction).
  119. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  120. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  121. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  122. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  123. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
  124. `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
  125. input) to speed up sequential decoding.
  126. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  127. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  128. projection_attentions (`tuple(torch.FloatTensor)`, *optional*):
  129. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  130. sequence_length)`.
  131. Attentions weights given by `Kosmos2ImageToTextProjection`, after the attention softmax, used to compute
  132. the weighted average in the self-attention heads.
  133. vision_model_output (`BaseModelOutputWithPooling`, *optional*):
  134. The output of the [`Kosmos2VisionModel`].
  135. """
  136. loss: Optional[torch.FloatTensor] = None
  137. logits: Optional[torch.FloatTensor] = None
  138. past_key_values: Optional[Cache] = None
  139. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  140. attentions: Optional[tuple[torch.FloatTensor]] = None
  141. image_embeds: Optional[torch.FloatTensor] = None
  142. projection_attentions: Optional[tuple[torch.FloatTensor]] = None
  143. vision_model_output: BaseModelOutputWithPooling = None
  144. def to_tuple(self) -> tuple[Any]:
  145. return tuple(
  146. self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
  147. for k in self.keys()
  148. )
  149. # Copied from transformers.models.clip.modeling_clip.CLIPVisionEmbeddings with CLIP->Kosmos2
  150. class Kosmos2VisionEmbeddings(nn.Module):
  151. def __init__(self, config: Kosmos2VisionConfig):
  152. super().__init__()
  153. self.config = config
  154. self.embed_dim = config.hidden_size
  155. self.image_size = config.image_size
  156. self.patch_size = config.patch_size
  157. self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
  158. self.patch_embedding = nn.Conv2d(
  159. in_channels=config.num_channels,
  160. out_channels=self.embed_dim,
  161. kernel_size=self.patch_size,
  162. stride=self.patch_size,
  163. bias=False,
  164. )
  165. self.num_patches = (self.image_size // self.patch_size) ** 2
  166. self.num_positions = self.num_patches + 1
  167. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  168. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  169. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  170. """
  171. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  172. images. This method is also adapted to support torch.jit tracing.
  173. Adapted from:
  174. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  175. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  176. """
  177. num_patches = embeddings.shape[1] - 1
  178. position_embedding = self.position_embedding.weight.unsqueeze(0)
  179. num_positions = position_embedding.shape[1] - 1
  180. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  181. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  182. return self.position_embedding(self.position_ids)
  183. class_pos_embed = position_embedding[:, :1]
  184. patch_pos_embed = position_embedding[:, 1:]
  185. dim = embeddings.shape[-1]
  186. new_height = height // self.patch_size
  187. new_width = width // self.patch_size
  188. sqrt_num_positions = torch_int(num_positions**0.5)
  189. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  190. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  191. patch_pos_embed = nn.functional.interpolate(
  192. patch_pos_embed,
  193. size=(new_height, new_width),
  194. mode="bicubic",
  195. align_corners=False,
  196. )
  197. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  198. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  199. def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
  200. batch_size, _, height, width = pixel_values.shape
  201. if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
  202. raise ValueError(
  203. f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})."
  204. )
  205. target_dtype = self.patch_embedding.weight.dtype
  206. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  207. patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
  208. class_embeds = self.class_embedding.expand(batch_size, 1, -1)
  209. embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
  210. if interpolate_pos_encoding:
  211. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  212. else:
  213. embeddings = embeddings + self.position_embedding(self.position_ids)
  214. return embeddings
  215. # Adapted from transformers.models.siglip.modeling_siglip.eager_attention_forward -> Kosmos2 doesn't cast attn weights to fp32
  216. def eager_attention_forward(
  217. module: nn.Module,
  218. query: torch.Tensor,
  219. key: torch.Tensor,
  220. value: torch.Tensor,
  221. attention_mask: Optional[torch.Tensor],
  222. scaling: float,
  223. dropout: float = 0.0,
  224. **kwargs,
  225. ):
  226. attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
  227. if attention_mask is not None:
  228. attn_weights = attn_weights + attention_mask
  229. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  230. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  231. attn_output = torch.matmul(attn_weights, value)
  232. attn_output = attn_output.transpose(1, 2).contiguous()
  233. return attn_output, attn_weights
  234. class Kosmos2VisionAttention(nn.Module):
  235. """Multi-headed attention from 'Attention Is All You Need' paper"""
  236. def __init__(self, config):
  237. super().__init__()
  238. self.config = config
  239. self.embed_dim = config.hidden_size
  240. self.num_heads = config.num_attention_heads
  241. self.head_dim = self.embed_dim // self.num_heads
  242. if self.head_dim * self.num_heads != self.embed_dim:
  243. raise ValueError(
  244. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  245. f" {self.num_heads})."
  246. )
  247. self.scale = self.head_dim**-0.5
  248. self.dropout = config.attention_dropout
  249. self.is_causal = False
  250. self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
  251. self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
  252. self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
  253. self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
  254. def forward(
  255. self,
  256. hidden_states: torch.Tensor,
  257. attention_mask: Optional[torch.Tensor] = None,
  258. causal_attention_mask: Optional[torch.Tensor] = None,
  259. output_attentions: Optional[bool] = False,
  260. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  261. """Input shape: Batch x Time x Channel"""
  262. batch_size, seq_length, embed_dim = hidden_states.shape
  263. queries = self.q_proj(hidden_states)
  264. keys = self.k_proj(hidden_states)
  265. values = self.v_proj(hidden_states)
  266. queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  267. keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  268. values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  269. # CLIP text model uses both `causal_attention_mask` and `attention_mask`
  270. # in case FA2 kernel is called, `is_causal` should be inferred from `causal_attention_mask`
  271. if self.config._attn_implementation != "flash_attention_2":
  272. if attention_mask is not None and causal_attention_mask is not None:
  273. attention_mask = attention_mask + causal_attention_mask
  274. elif causal_attention_mask is not None:
  275. attention_mask = causal_attention_mask
  276. else:
  277. self.is_causal = causal_attention_mask is not None
  278. attention_interface: Callable = eager_attention_forward
  279. if self.config._attn_implementation != "eager":
  280. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  281. attn_output, attn_weights = attention_interface(
  282. self,
  283. queries,
  284. keys,
  285. values,
  286. attention_mask,
  287. is_causal=self.is_causal,
  288. scaling=self.scale,
  289. dropout=0.0 if not self.training else self.dropout,
  290. )
  291. attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
  292. attn_output = self.out_proj(attn_output)
  293. if not output_attentions:
  294. attn_weights = None
  295. return attn_output, attn_weights
  296. # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Kosmos2Vision
  297. class Kosmos2VisionMLP(nn.Module):
  298. def __init__(self, config):
  299. super().__init__()
  300. self.config = config
  301. self.activation_fn = ACT2FN[config.hidden_act]
  302. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  303. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  304. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  305. hidden_states = self.fc1(hidden_states)
  306. hidden_states = self.activation_fn(hidden_states)
  307. hidden_states = self.fc2(hidden_states)
  308. return hidden_states
  309. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoderLayer with AltCLIP->Kosmos2Vision
  310. class Kosmos2VisionEncoderLayer(GradientCheckpointingLayer):
  311. def __init__(self, config: Kosmos2VisionConfig):
  312. super().__init__()
  313. self.embed_dim = config.hidden_size
  314. self.self_attn = Kosmos2VisionAttention(config)
  315. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  316. self.mlp = Kosmos2VisionMLP(config)
  317. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  318. def forward(
  319. self,
  320. hidden_states: torch.Tensor,
  321. attention_mask: torch.Tensor,
  322. causal_attention_mask: torch.Tensor,
  323. output_attentions: Optional[bool] = False,
  324. ) -> tuple[torch.FloatTensor]:
  325. """
  326. Args:
  327. hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
  328. attention_mask (`torch.FloatTensor`): attention mask of size
  329. `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
  330. `(config.encoder_attention_heads,)`.
  331. output_attentions (`bool`, *optional*):
  332. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  333. returned tensors for more detail.
  334. """
  335. residual = hidden_states
  336. hidden_states = self.layer_norm1(hidden_states)
  337. hidden_states, attn_weights = self.self_attn(
  338. hidden_states=hidden_states,
  339. attention_mask=attention_mask,
  340. causal_attention_mask=causal_attention_mask,
  341. output_attentions=output_attentions,
  342. )
  343. hidden_states = residual + hidden_states
  344. residual = hidden_states
  345. hidden_states = self.layer_norm2(hidden_states)
  346. hidden_states = self.mlp(hidden_states)
  347. hidden_states = residual + hidden_states
  348. outputs = (hidden_states,)
  349. if output_attentions:
  350. outputs += (attn_weights,)
  351. return outputs
  352. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Kosmos2Vision
  353. class Kosmos2VisionEncoder(nn.Module):
  354. """
  355. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  356. [`Kosmos2VisionEncoderLayer`].
  357. Args:
  358. config: Kosmos2VisionConfig
  359. """
  360. def __init__(self, config: Kosmos2VisionConfig):
  361. super().__init__()
  362. self.config = config
  363. self.layers = nn.ModuleList([Kosmos2VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  364. self.gradient_checkpointing = False
  365. @can_return_tuple
  366. def forward(
  367. self,
  368. inputs_embeds,
  369. attention_mask: Optional[torch.Tensor] = None,
  370. causal_attention_mask: Optional[torch.Tensor] = None,
  371. output_attentions: Optional[bool] = None,
  372. output_hidden_states: Optional[bool] = None,
  373. return_dict: Optional[bool] = None,
  374. ) -> Union[tuple, BaseModelOutput]:
  375. r"""
  376. Args:
  377. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
  378. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
  379. This is useful if you want more control over how to convert `input_ids` indices into associated vectors
  380. than the model's internal embedding lookup matrix.
  381. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  382. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  383. - 1 for tokens that are **not masked**,
  384. - 0 for tokens that are **masked**.
  385. [What are attention masks?](../glossary#attention-mask)
  386. causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  387. Causal mask for the text model. Mask values selected in `[0, 1]`:
  388. - 1 for tokens that are **not masked**,
  389. - 0 for tokens that are **masked**.
  390. [What are attention masks?](../glossary#attention-mask)
  391. output_attentions (`bool`, *optional*):
  392. Whether or not to return the attentions tensors of all attention layers. See `attentions` under
  393. returned tensors for more detail.
  394. output_hidden_states (`bool`, *optional*):
  395. Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
  396. for more detail.
  397. return_dict (`bool`, *optional*):
  398. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
  399. """
  400. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  401. output_hidden_states = (
  402. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  403. )
  404. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  405. encoder_states = () if output_hidden_states else None
  406. all_attentions = () if output_attentions else None
  407. hidden_states = inputs_embeds
  408. for idx, encoder_layer in enumerate(self.layers):
  409. if output_hidden_states:
  410. encoder_states = encoder_states + (hidden_states,)
  411. layer_outputs = encoder_layer(
  412. hidden_states,
  413. attention_mask,
  414. causal_attention_mask,
  415. output_attentions=output_attentions,
  416. )
  417. hidden_states = layer_outputs[0]
  418. if output_attentions:
  419. all_attentions = all_attentions + (layer_outputs[1],)
  420. if output_hidden_states:
  421. encoder_states = encoder_states + (hidden_states,)
  422. return BaseModelOutput(
  423. last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
  424. )
  425. # Similar to `transformers.models.clip.modeling_clip.CLIPVisionTransformer` but without docstring for `forward`
  426. class Kosmos2VisionTransformer(nn.Module):
  427. # Copied from transformers.models.altclip.modeling_altclip.AltCLIPVisionTransformer.__init__ with AltCLIPVision->Kosmos2Vision,ALTCLIP_VISION->KOSMOS2_VISION,AltCLIP->Kosmos2Vision
  428. def __init__(self, config: Kosmos2VisionConfig):
  429. super().__init__()
  430. self.config = config
  431. embed_dim = config.hidden_size
  432. self.embeddings = Kosmos2VisionEmbeddings(config)
  433. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  434. self.encoder = Kosmos2VisionEncoder(config)
  435. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  436. def forward(
  437. self,
  438. pixel_values: Optional[torch.FloatTensor] = None,
  439. output_attentions: Optional[bool] = None,
  440. output_hidden_states: Optional[bool] = None,
  441. interpolate_pos_encoding: bool = False,
  442. return_dict: Optional[bool] = None,
  443. ) -> Union[tuple, BaseModelOutputWithPooling]:
  444. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  445. output_hidden_states = (
  446. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  447. )
  448. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  449. if pixel_values is None:
  450. raise ValueError("You have to specify pixel_values")
  451. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  452. hidden_states = self.pre_layrnorm(hidden_states)
  453. encoder_outputs = self.encoder(
  454. inputs_embeds=hidden_states,
  455. output_attentions=output_attentions,
  456. output_hidden_states=output_hidden_states,
  457. return_dict=return_dict,
  458. )
  459. last_hidden_state = encoder_outputs[0]
  460. pooled_output = last_hidden_state[:, 0, :]
  461. pooled_output = self.post_layernorm(pooled_output)
  462. if not return_dict:
  463. return (last_hidden_state, pooled_output) + encoder_outputs[1:]
  464. return BaseModelOutputWithPooling(
  465. last_hidden_state=last_hidden_state,
  466. pooler_output=pooled_output,
  467. hidden_states=encoder_outputs.hidden_states,
  468. attentions=encoder_outputs.attentions,
  469. )
  470. # Similar to `transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding` but allowing to pass `position_ids`
  471. class Kosmos2TextSinusoidalPositionalEmbedding(nn.Module):
  472. """This module produces sinusoidal positional embeddings of any length."""
  473. # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.__init__
  474. def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None):
  475. super().__init__()
  476. self.offset = 2
  477. self.embedding_dim = embedding_dim
  478. self.padding_idx = padding_idx
  479. self.make_weights(num_positions + self.offset, embedding_dim, padding_idx)
  480. # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.make_weights
  481. def make_weights(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
  482. emb_weights = self.get_embedding(num_embeddings, embedding_dim, padding_idx)
  483. if hasattr(self, "weights"):
  484. # in forward put the weights on the correct dtype and device of the param
  485. emb_weights = emb_weights.to(dtype=self.weights.dtype, device=self.weights.device)
  486. self.register_buffer("weights", emb_weights, persistent=False)
  487. @staticmethod
  488. # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.get_embedding
  489. def get_embedding(num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None):
  490. """
  491. Build sinusoidal embeddings.
  492. This matches the implementation in tensor2tensor, but differs slightly from the description in Section 3.5 of
  493. "Attention Is All You Need".
  494. """
  495. half_dim = embedding_dim // 2
  496. emb = math.log(10000) / (half_dim - 1)
  497. emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
  498. emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
  499. emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
  500. if embedding_dim % 2 == 1:
  501. # zero pad
  502. emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
  503. if padding_idx is not None:
  504. emb[padding_idx, :] = 0
  505. return emb.to(torch.get_default_dtype())
  506. @torch.no_grad()
  507. def forward(
  508. self,
  509. input_ids: Optional[torch.Tensor] = None,
  510. inputs_embeds: Optional[torch.Tensor] = None,
  511. past_key_values_length: int = 0,
  512. position_ids: Optional[torch.Tensor] = None,
  513. ):
  514. if input_ids is not None:
  515. bsz, seq_len = input_ids.size()
  516. if position_ids is None:
  517. # Create the position ids from the input token ids. Any padded tokens remain padded.
  518. position_ids = create_position_ids_from_input_ids(
  519. input_ids, self.padding_idx, past_key_values_length
  520. ).to(input_ids.device)
  521. else:
  522. bsz, seq_len = inputs_embeds.size()[:-1]
  523. if position_ids is None:
  524. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, past_key_values_length)
  525. # expand embeddings if needed
  526. max_pos = self.padding_idx + 1 + seq_len + past_key_values_length
  527. if max_pos > self.weights.size(0):
  528. self.make_weights(max_pos + self.offset, self.embedding_dim, self.padding_idx)
  529. return self.weights.index_select(0, position_ids.view(-1)).view(bsz, seq_len, self.weights.shape[-1]).detach()
  530. # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding.create_position_ids_from_inputs_embeds
  531. def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_length):
  532. """
  533. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  534. Args:
  535. inputs_embeds: torch.Tensor
  536. Returns: torch.Tensor
  537. """
  538. input_shape = inputs_embeds.size()[:-1]
  539. sequence_length = input_shape[1]
  540. position_ids = torch.arange(
  541. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  542. )
  543. return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
  544. class KosmosTextAttention(nn.Module):
  545. """Multi-headed attention from 'Attention Is All You Need' paper"""
  546. # Similar to transformers.models.bart.modeling_bart.BartAttention.__init__ except an additional `inner_attn_ln`.
  547. def __init__(
  548. self,
  549. config,
  550. embed_dim: int,
  551. num_heads: int,
  552. dropout: float = 0.0,
  553. is_decoder: Optional[bool] = False,
  554. add_inner_attn_layernorm: Optional[bool] = False,
  555. bias: Optional[bool] = True,
  556. layer_idx: Optional[bool] = None,
  557. ):
  558. super().__init__()
  559. self.config = config
  560. self.embed_dim = embed_dim
  561. self.num_heads = num_heads
  562. self.dropout = dropout
  563. self.head_dim = embed_dim // num_heads
  564. if (self.head_dim * num_heads) != self.embed_dim:
  565. raise ValueError(
  566. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
  567. f" and `num_heads`: {num_heads})."
  568. )
  569. self.scaling = self.head_dim**-0.5
  570. self.is_decoder = is_decoder
  571. self.layer_idx = layer_idx
  572. self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  573. self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  574. self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  575. self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
  576. # End opy
  577. self.inner_attn_ln = None
  578. if add_inner_attn_layernorm:
  579. self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  580. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  581. def forward(
  582. self,
  583. hidden_states: torch.Tensor,
  584. encoder_hidden_states: Optional[torch.Tensor] = None,
  585. past_key_values: Optional[Cache] = None,
  586. attention_mask: Optional[torch.Tensor] = None,
  587. layer_head_mask: Optional[torch.Tensor] = None,
  588. output_attentions: bool = False,
  589. cache_position: Optional[torch.Tensor] = None,
  590. **kwargs,
  591. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
  592. """Input shape: Batch x Time x Channel"""
  593. # if key_value_states are provided this layer is used as a cross-attention layer
  594. # for the decoder
  595. is_cross_attention = encoder_hidden_states is not None
  596. batch_size, seq_length = hidden_states.shape[:2]
  597. query_states = self.q_proj(hidden_states)
  598. query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
  599. is_updated = False
  600. if past_key_values is not None:
  601. if isinstance(past_key_values, EncoderDecoderCache):
  602. is_updated = past_key_values.is_updated.get(self.layer_idx)
  603. if is_cross_attention:
  604. # after the first generated id, we can subsequently re-use all key/value_states from cache
  605. curr_past_key_value = past_key_values.cross_attention_cache
  606. else:
  607. curr_past_key_value = past_key_values.self_attention_cache
  608. else:
  609. curr_past_key_value = past_key_values
  610. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  611. if is_cross_attention and past_key_values is not None and is_updated:
  612. # reuse k,v, cross_attentions
  613. key_states = curr_past_key_value.layers[self.layer_idx].keys
  614. value_states = curr_past_key_value.layers[self.layer_idx].values
  615. else:
  616. key_states = self.k_proj(current_states)
  617. value_states = self.v_proj(current_states)
  618. key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
  619. value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
  620. if past_key_values is not None:
  621. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  622. cache_position = cache_position if not is_cross_attention else None
  623. key_states, value_states = curr_past_key_value.update(
  624. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  625. )
  626. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  627. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  628. past_key_values.is_updated[self.layer_idx] = True
  629. attention_interface: Callable = eager_attention_forward
  630. if self.config._attn_implementation != "eager":
  631. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  632. attn_output, attn_weights = attention_interface(
  633. self,
  634. query_states,
  635. key_states,
  636. value_states,
  637. attention_mask,
  638. dropout=0.0 if not self.training else self.dropout,
  639. scaling=self.scaling,
  640. **kwargs,
  641. )
  642. attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
  643. if self.inner_attn_ln is not None:
  644. attn_output = self.inner_attn_ln(attn_output)
  645. attn_output = self.out_proj(attn_output)
  646. return attn_output, attn_weights
  647. class Kosmos2TextFFN(nn.Module):
  648. def __init__(self, config: Kosmos2TextConfig):
  649. super().__init__()
  650. self.dropout = config.dropout
  651. self.activation_fn = ACT2FN[config.activation_function]
  652. self.activation_dropout = config.activation_dropout
  653. self.fc1 = nn.Linear(config.embed_dim, config.ffn_dim)
  654. self.fc2 = nn.Linear(config.ffn_dim, config.embed_dim)
  655. self.ffn_layernorm = nn.LayerNorm(config.ffn_dim, eps=config.layer_norm_eps)
  656. def forward(self, hidden_states):
  657. hidden_states = self.activation_fn(self.fc1(hidden_states))
  658. hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
  659. hidden_states = self.ffn_layernorm(hidden_states)
  660. hidden_states = self.fc2(hidden_states)
  661. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  662. return hidden_states
  663. class Kosmos2TextBlock(GradientCheckpointingLayer):
  664. def __init__(self, config: Kosmos2TextConfig, layer_idx=None):
  665. super().__init__()
  666. self.embed_dim = config.embed_dim
  667. self.self_attn = KosmosTextAttention(
  668. config,
  669. embed_dim=self.embed_dim,
  670. num_heads=config.attention_heads,
  671. dropout=config.attention_dropout,
  672. is_decoder=True,
  673. add_inner_attn_layernorm=True,
  674. layer_idx=layer_idx,
  675. )
  676. self.dropout = config.dropout
  677. self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  678. if config.add_cross_attention:
  679. self.encoder_attn = KosmosTextAttention(
  680. config,
  681. embed_dim=self.embed_dim,
  682. num_heads=config.attention_heads,
  683. dropout=config.attention_dropout,
  684. is_decoder=True,
  685. add_inner_attn_layernorm=False,
  686. layer_idx=layer_idx,
  687. )
  688. self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  689. self.ffn = Kosmos2TextFFN(config)
  690. self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  691. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  692. def forward(
  693. self,
  694. hidden_states: torch.Tensor,
  695. attention_mask: Optional[torch.Tensor] = None,
  696. encoder_hidden_states: Optional[torch.Tensor] = None,
  697. encoder_attention_mask: Optional[torch.Tensor] = None,
  698. layer_head_mask: Optional[torch.Tensor] = None,
  699. cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
  700. past_key_values: Optional[Cache] = None,
  701. output_attentions: Optional[bool] = False,
  702. use_cache: Optional[bool] = True,
  703. cache_position: Optional[torch.Tensor] = None,
  704. **kwargs,
  705. ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  706. residual = hidden_states
  707. hidden_states = self.self_attn_layer_norm(hidden_states)
  708. hidden_states, self_attn_weights = self.self_attn(
  709. hidden_states=hidden_states,
  710. past_key_values=past_key_values,
  711. attention_mask=attention_mask,
  712. layer_head_mask=layer_head_mask,
  713. output_attentions=output_attentions,
  714. cache_position=cache_position,
  715. **kwargs,
  716. )
  717. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  718. hidden_states = residual + hidden_states
  719. # Cross-Attention Block
  720. cross_attn_weights = None
  721. if encoder_hidden_states is not None:
  722. if not hasattr(self, "encoder_attn"):
  723. raise ValueError(
  724. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  725. " by setting `config.add_cross_attention=True`"
  726. )
  727. residual = hidden_states
  728. hidden_states = self.encoder_attn_layer_norm(hidden_states)
  729. hidden_states, cross_attn_weights = self.encoder_attn(
  730. hidden_states=hidden_states,
  731. encoder_hidden_states=encoder_hidden_states,
  732. attention_mask=encoder_attention_mask,
  733. layer_head_mask=cross_attn_layer_head_mask,
  734. past_key_values=past_key_values,
  735. output_attentions=output_attentions,
  736. cache_position=cache_position,
  737. **kwargs,
  738. )
  739. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  740. hidden_states = residual + hidden_states
  741. # Fully Connected
  742. residual = hidden_states
  743. hidden_states = self.final_layer_norm(hidden_states)
  744. # FFN
  745. hidden_states = self.ffn(hidden_states)
  746. hidden_states = residual + hidden_states
  747. outputs = (hidden_states,)
  748. if output_attentions:
  749. outputs += (self_attn_weights, cross_attn_weights)
  750. return outputs
  751. class Kosmos2TextTransformer(nn.Module):
  752. """
  753. Transformer decoder consisting of `config.layers` layers. Each layer is a [`Kosmos2TextBlock`].
  754. Args:
  755. config: Kosmos2TextConfig
  756. """
  757. def __init__(self, config: Kosmos2TextConfig):
  758. super().__init__()
  759. self.config = config
  760. self.dropout = config.dropout
  761. self.layerdrop = config.layerdrop
  762. self.embed_scale = math.sqrt(config.embed_dim) if config.scale_embedding else 1.0
  763. self.embed_tokens = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_token_id)
  764. self.embed_positions = Kosmos2TextSinusoidalPositionalEmbedding(
  765. num_positions=config.max_position_embeddings,
  766. embedding_dim=config.embed_dim,
  767. padding_idx=config.pad_token_id,
  768. )
  769. self.layers = nn.ModuleList([Kosmos2TextBlock(config, layer_idx=i) for i in range(config.layers)])
  770. self.layer_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
  771. self.gradient_checkpointing = False
  772. def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
  773. # create causal mask
  774. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  775. combined_attention_mask = None
  776. if input_shape[-1] > 1:
  777. combined_attention_mask = _make_causal_mask(
  778. input_shape,
  779. inputs_embeds.dtype,
  780. device=inputs_embeds.device,
  781. past_key_values_length=past_key_values_length,
  782. )
  783. if attention_mask is not None:
  784. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  785. expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
  786. inputs_embeds.device
  787. )
  788. combined_attention_mask = (
  789. expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
  790. )
  791. return combined_attention_mask
  792. def forward_embedding(
  793. self,
  794. input_ids,
  795. inputs_embeds: Optional[torch.Tensor] = None,
  796. image_embeds: Optional[torch.Tensor] = None,
  797. img_input_mask: Optional[torch.Tensor] = None,
  798. past_key_values_length: int = 0,
  799. position_ids: Optional[torch.Tensor] = None,
  800. ):
  801. # The argument `inputs_embeds` should be the one without being multiplied by `self.embed_scale`.
  802. if inputs_embeds is None:
  803. inputs_embeds = self.embed_tokens(input_ids)
  804. if image_embeds is not None:
  805. inputs_embeds[img_input_mask.to(dtype=torch.bool)] = image_embeds.to(inputs_embeds.device).view(
  806. -1, image_embeds.size(-1)
  807. )
  808. inputs_embeds = inputs_embeds * self.embed_scale
  809. # embed positions
  810. positions = self.embed_positions(
  811. input_ids=input_ids,
  812. inputs_embeds=inputs_embeds,
  813. past_key_values_length=past_key_values_length,
  814. position_ids=position_ids,
  815. )
  816. positions = positions.to(inputs_embeds.device)
  817. hidden_states = inputs_embeds + positions
  818. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  819. return hidden_states
  820. def forward(
  821. self,
  822. input_ids: Optional[torch.Tensor] = None,
  823. attention_mask: Optional[torch.Tensor] = None,
  824. image_embeds: Optional[torch.Tensor] = None,
  825. image_embeds_position_mask: Optional[torch.Tensor] = None,
  826. encoder_hidden_states: Optional[torch.Tensor] = None,
  827. encoder_attention_mask: Optional[torch.Tensor] = None,
  828. head_mask: Optional[torch.Tensor] = None,
  829. cross_attn_head_mask: Optional[torch.Tensor] = None,
  830. past_key_values: Optional[Cache] = None,
  831. inputs_embeds: Optional[torch.Tensor] = None,
  832. position_ids: Optional[torch.Tensor] = None,
  833. use_cache: Optional[bool] = None,
  834. output_attentions: Optional[bool] = None,
  835. output_hidden_states: Optional[bool] = None,
  836. return_dict: Optional[bool] = None,
  837. cache_position: Optional[torch.Tensor] = None,
  838. **kwargs: Unpack[FlashAttentionKwargs],
  839. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  840. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  841. output_hidden_states = (
  842. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  843. )
  844. use_cache = use_cache if use_cache is not None else self.config.use_cache
  845. if input_ids is not None and inputs_embeds is not None:
  846. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  847. elif input_ids is not None:
  848. input_shape = input_ids.shape
  849. input_ids = input_ids.view(-1, input_shape[-1])
  850. elif inputs_embeds is not None:
  851. input_shape = inputs_embeds.size()[:-1]
  852. else:
  853. raise ValueError("You have to specify either input_ids or inputs_embeds")
  854. if self.gradient_checkpointing and self.training:
  855. if use_cache:
  856. logger.warning_once(
  857. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  858. )
  859. use_cache = False
  860. if use_cache and past_key_values is None:
  861. past_key_values = (
  862. EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  863. if encoder_hidden_states is not None
  864. else DynamicCache(config=self.config)
  865. )
  866. if use_cache and isinstance(past_key_values, tuple):
  867. logger.warning_once(
  868. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  869. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  870. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  871. )
  872. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  873. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  874. # We don't need img info. when `past_key_values_length` > 0
  875. if past_key_values_length > 0:
  876. image_embeds = None
  877. image_embeds_position_mask = None
  878. hidden_states = self.forward_embedding(
  879. input_ids=input_ids,
  880. inputs_embeds=inputs_embeds,
  881. image_embeds=image_embeds,
  882. img_input_mask=image_embeds_position_mask,
  883. past_key_values_length=past_key_values_length,
  884. position_ids=position_ids,
  885. )
  886. attention_mask = self._prepare_decoder_attention_mask(
  887. attention_mask, input_shape, hidden_states, past_key_values_length
  888. )
  889. # expand encoder attention mask
  890. if encoder_hidden_states is not None and encoder_attention_mask is not None:
  891. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
  892. encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
  893. hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
  894. # decoder layers
  895. all_hidden_states = () if output_hidden_states else None
  896. all_self_attns = () if output_attentions else None
  897. all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
  898. # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
  899. for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
  900. if attn_mask is not None:
  901. if attn_mask.size()[0] != (len(self.layers)):
  902. raise ValueError(
  903. f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
  904. f" {head_mask.size()[0]}."
  905. )
  906. for idx, decoder_layer in enumerate(self.layers):
  907. # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description)
  908. if output_hidden_states:
  909. all_hidden_states += (hidden_states,)
  910. if self.training:
  911. dropout_probability = torch.rand([])
  912. if dropout_probability < self.layerdrop:
  913. continue
  914. layer_outputs = decoder_layer(
  915. hidden_states,
  916. attention_mask,
  917. encoder_hidden_states,
  918. encoder_attention_mask=encoder_attention_mask,
  919. layer_head_mask=(head_mask[idx] if head_mask is not None else None),
  920. cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None),
  921. past_key_values=past_key_values,
  922. output_attentions=output_attentions,
  923. use_cache=use_cache,
  924. cache_position=cache_position,
  925. **kwargs,
  926. )
  927. hidden_states = layer_outputs[0]
  928. if output_attentions:
  929. all_self_attns += (layer_outputs[1],)
  930. if encoder_hidden_states is not None:
  931. all_cross_attentions += (layer_outputs[2],)
  932. # add final layer norm
  933. hidden_states = self.layer_norm(hidden_states)
  934. # add hidden states from the last decoder layer
  935. if output_hidden_states:
  936. all_hidden_states += (hidden_states,)
  937. return BaseModelOutputWithPastAndCrossAttentions(
  938. last_hidden_state=hidden_states,
  939. past_key_values=past_key_values,
  940. hidden_states=all_hidden_states,
  941. attentions=all_self_attns,
  942. cross_attentions=all_cross_attentions,
  943. )
  944. @auto_docstring
  945. class Kosmos2PreTrainedModel(PreTrainedModel):
  946. config: Kosmos2Config
  947. supports_gradient_checkpointing = True
  948. _no_split_modules = ["Kosmos2VisionEncoderLayer", "Kosmos2TextBlock"]
  949. _supports_attention_backend = True
  950. _supports_flash_attn = True
  951. _supports_sdpa = True
  952. def _init_weights(self, module: nn.Module):
  953. """Initialize the weights"""
  954. if isinstance(self, Kosmos2VisionModel):
  955. factor = self.config.initializer_factor
  956. elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)):
  957. factor = self.config.vision_config.initializer_factor
  958. if isinstance(self, (Kosmos2TextModel, Kosmos2TextForCausalLM)):
  959. std = self.config.init_std
  960. elif isinstance(self, (Kosmos2Model, Kosmos2ForConditionalGeneration)):
  961. std = self.config.text_config.init_std
  962. if isinstance(module, Kosmos2VisionEmbeddings):
  963. nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
  964. nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
  965. nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
  966. elif isinstance(module, Kosmos2VisionAttention):
  967. in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  968. out_proj_std = (module.embed_dim**-0.5) * factor
  969. nn.init.normal_(module.q_proj.weight, std=in_proj_std)
  970. nn.init.normal_(module.k_proj.weight, std=in_proj_std)
  971. nn.init.normal_(module.v_proj.weight, std=in_proj_std)
  972. nn.init.normal_(module.out_proj.weight, std=out_proj_std)
  973. elif isinstance(module, Kosmos2VisionMLP):
  974. in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
  975. fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
  976. nn.init.normal_(module.fc1.weight, std=fc_std)
  977. nn.init.normal_(module.fc2.weight, std=in_proj_std)
  978. elif isinstance(module, KosmosTextAttention):
  979. nn.init.normal_(module.q_proj.weight, std=std)
  980. nn.init.normal_(module.k_proj.weight, std=std)
  981. nn.init.normal_(module.v_proj.weight, std=std)
  982. nn.init.normal_(module.out_proj.weight, std=std)
  983. elif isinstance(module, Kosmos2TextFFN):
  984. nn.init.normal_(module.fc1.weight, std=std)
  985. nn.init.normal_(module.fc2.weight, std=std)
  986. elif isinstance(module, Kosmos2TextForCausalLM):
  987. nn.init.normal_(module.lm_head.weight, std=std)
  988. elif isinstance(module, Kosmos2ImageToTextProjection):
  989. nn.init.normal_(module.dense.weight, std=std)
  990. nn.init.normal_(module.latent_query)
  991. elif isinstance(module, Kosmos2TextTransformer):
  992. module.embed_tokens.weight.data.normal_(mean=0.0, std=std)
  993. if module.embed_tokens.padding_idx is not None:
  994. module.embed_tokens.weight.data[module.embed_tokens.padding_idx].zero_()
  995. elif isinstance(module, nn.LayerNorm):
  996. module.weight.data.fill_(1.0)
  997. module.bias.data.zero_()
  998. if isinstance(module, nn.Linear) and module.bias is not None:
  999. module.bias.data.zero_()
  1000. class Kosmos2VisionModel(Kosmos2PreTrainedModel):
  1001. config: Kosmos2VisionConfig
  1002. main_input_name = "pixel_values"
  1003. # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.__init__ with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model
  1004. def __init__(self, config: Kosmos2VisionConfig):
  1005. super().__init__(config)
  1006. self.model = Kosmos2VisionTransformer(config)
  1007. # Initialize weights and apply final processing
  1008. self.post_init()
  1009. # Copied from transformers.models.clip.modeling_clip.CLIPVisionModel.get_input_embeddings with CLIP_VISION->KOSMOS2_VISION,CLIP->Kosmos2,self.vision_model->self.model
  1010. def get_input_embeddings(self) -> nn.Module:
  1011. return self.model.embeddings.patch_embedding
  1012. @auto_docstring
  1013. def forward(
  1014. self,
  1015. pixel_values: Optional[torch.FloatTensor] = None,
  1016. output_attentions: Optional[bool] = None,
  1017. output_hidden_states: Optional[bool] = None,
  1018. interpolate_pos_encoding: bool = False,
  1019. return_dict: Optional[bool] = None,
  1020. ) -> Union[tuple, BaseModelOutputWithPooling]:
  1021. return self.model(
  1022. pixel_values=pixel_values,
  1023. output_attentions=output_attentions,
  1024. output_hidden_states=output_hidden_states,
  1025. interpolate_pos_encoding=interpolate_pos_encoding,
  1026. return_dict=return_dict,
  1027. )
  1028. class Kosmos2TextModel(Kosmos2PreTrainedModel):
  1029. config: Kosmos2TextConfig
  1030. def __init__(self, config: Kosmos2TextConfig):
  1031. super().__init__(config)
  1032. self.model = Kosmos2TextTransformer(config)
  1033. # Initialize weights and apply final processing
  1034. self.post_init()
  1035. def get_input_embeddings(self) -> nn.Module:
  1036. return self.model.embed_tokens
  1037. @can_return_tuple
  1038. @auto_docstring
  1039. def forward(
  1040. self,
  1041. input_ids: Optional[torch.Tensor] = None,
  1042. attention_mask: Optional[torch.Tensor] = None,
  1043. image_embeds: Optional[torch.Tensor] = None,
  1044. image_embeds_position_mask: Optional[torch.Tensor] = None,
  1045. encoder_hidden_states: Optional[torch.Tensor] = None,
  1046. encoder_attention_mask: Optional[torch.Tensor] = None,
  1047. head_mask: Optional[torch.Tensor] = None,
  1048. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1049. past_key_values: Optional[Cache] = None,
  1050. inputs_embeds: Optional[torch.Tensor] = None,
  1051. position_ids: Optional[torch.Tensor] = None,
  1052. use_cache: Optional[bool] = None,
  1053. output_attentions: Optional[bool] = None,
  1054. output_hidden_states: Optional[bool] = None,
  1055. return_dict: Optional[bool] = None,
  1056. cache_position: Optional[torch.Tensor] = None,
  1057. **kwargs: Unpack[FlashAttentionKwargs],
  1058. ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
  1059. r"""
  1060. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  1061. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  1062. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1063. Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
  1064. 1]`:
  1065. - 1 for places where to put the image features,
  1066. - 0 for places that are not for image features (i.e. for text tokens).
  1067. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1068. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1069. - 1 indicates the head is **not masked**,
  1070. - 0 indicates the head is **masked**.
  1071. """
  1072. return self.model(
  1073. input_ids=input_ids,
  1074. attention_mask=attention_mask,
  1075. image_embeds=image_embeds,
  1076. image_embeds_position_mask=image_embeds_position_mask,
  1077. encoder_hidden_states=encoder_hidden_states,
  1078. encoder_attention_mask=encoder_attention_mask,
  1079. head_mask=head_mask,
  1080. cross_attn_head_mask=cross_attn_head_mask,
  1081. past_key_values=past_key_values,
  1082. inputs_embeds=inputs_embeds,
  1083. position_ids=position_ids,
  1084. use_cache=use_cache,
  1085. output_attentions=output_attentions,
  1086. output_hidden_states=output_hidden_states,
  1087. return_dict=return_dict,
  1088. cache_position=cache_position,
  1089. **kwargs,
  1090. )
  1091. @auto_docstring(
  1092. custom_intro="""
  1093. The text model from KOSMOS-2 with a language modeling head on top (linear layer with weights tied to the input
  1094. embeddings).
  1095. """
  1096. )
  1097. class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin):
  1098. config: Kosmos2TextConfig
  1099. _tied_weights_keys = ["lm_head.weight"]
  1100. def __init__(self, config: Kosmos2TextConfig):
  1101. super().__init__(config)
  1102. self.model = Kosmos2TextTransformer(config)
  1103. self.lm_head = nn.Linear(in_features=config.embed_dim, out_features=config.vocab_size, bias=False)
  1104. # Initialize weights and apply final processing
  1105. self.post_init()
  1106. def get_input_embeddings(self) -> nn.Module:
  1107. return self.model.embed_tokens
  1108. def get_output_embeddings(self) -> nn.Module:
  1109. return self.lm_head
  1110. @can_return_tuple
  1111. @auto_docstring
  1112. def forward(
  1113. self,
  1114. input_ids: Optional[torch.Tensor] = None,
  1115. attention_mask: Optional[torch.Tensor] = None,
  1116. image_embeds: Optional[torch.Tensor] = None,
  1117. image_embeds_position_mask: Optional[torch.Tensor] = None,
  1118. encoder_hidden_states: Optional[torch.Tensor] = None,
  1119. encoder_attention_mask: Optional[torch.Tensor] = None,
  1120. head_mask: Optional[torch.Tensor] = None,
  1121. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1122. past_key_values: Optional[Cache] = None,
  1123. inputs_embeds: Optional[torch.Tensor] = None,
  1124. position_ids: Optional[torch.Tensor] = None,
  1125. labels: Optional[torch.LongTensor] = None,
  1126. use_cache: Optional[bool] = None,
  1127. output_attentions: Optional[bool] = None,
  1128. output_hidden_states: Optional[bool] = None,
  1129. return_dict: Optional[bool] = None,
  1130. cache_position: Optional[torch.Tensor] = None,
  1131. **kwargs: Unpack[TransformersKwargs],
  1132. ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
  1133. r"""
  1134. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  1135. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  1136. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1137. Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
  1138. 1]`:
  1139. - 1 for places where to put the image features,
  1140. - 0 for places that are not for image features (i.e. for text tokens).
  1141. cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
  1142. Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
  1143. - 1 indicates the head is **not masked**,
  1144. - 0 indicates the head is **masked**.
  1145. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1146. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1147. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1148. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1149. """
  1150. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1151. if labels is not None:
  1152. if use_cache:
  1153. logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.")
  1154. use_cache = False
  1155. outputs = self.model(
  1156. input_ids=input_ids,
  1157. attention_mask=attention_mask,
  1158. image_embeds=image_embeds,
  1159. image_embeds_position_mask=image_embeds_position_mask,
  1160. encoder_hidden_states=encoder_hidden_states,
  1161. encoder_attention_mask=encoder_attention_mask,
  1162. head_mask=head_mask,
  1163. cross_attn_head_mask=cross_attn_head_mask,
  1164. past_key_values=past_key_values,
  1165. inputs_embeds=inputs_embeds,
  1166. position_ids=position_ids,
  1167. use_cache=use_cache,
  1168. output_attentions=output_attentions,
  1169. output_hidden_states=output_hidden_states,
  1170. return_dict=True,
  1171. cache_position=cache_position,
  1172. **kwargs,
  1173. )
  1174. lm_logits = self.lm_head(outputs[0])
  1175. loss = None
  1176. if labels is not None:
  1177. loss = self.loss_function(logits=lm_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
  1178. return CausalLMOutputWithCrossAttentions(
  1179. loss=loss,
  1180. logits=lm_logits,
  1181. past_key_values=outputs.past_key_values,
  1182. hidden_states=outputs.hidden_states,
  1183. attentions=outputs.attentions,
  1184. cross_attentions=outputs.cross_attentions,
  1185. )
  1186. def prepare_inputs_for_generation(
  1187. self,
  1188. input_ids,
  1189. image_embeds=None,
  1190. image_embeds_position_mask=None,
  1191. past_key_values=None,
  1192. attention_mask=None,
  1193. inputs_embeds=None,
  1194. use_cache=None,
  1195. cache_position=None,
  1196. **model_kwargs,
  1197. ):
  1198. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  1199. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  1200. if cache_position[0] != 0:
  1201. image_embeds = None
  1202. image_embeds_position_mask = None
  1203. # appending `False` to `image_embeds_position_mask` (because `input_ids` grows during generation)
  1204. elif image_embeds_position_mask is not None:
  1205. batch_size, seq_len = inputs_embeds.size()[:-1] if inputs_embeds is not None else input_ids.size()
  1206. mask_len = image_embeds_position_mask.size()[-1]
  1207. image_embeds_position_mask = torch.cat(
  1208. (
  1209. image_embeds_position_mask,
  1210. torch.zeros(size=(batch_size, seq_len - mask_len), dtype=torch.bool, device=input_ids.device),
  1211. ),
  1212. dim=1,
  1213. )
  1214. model_inputs = super().prepare_inputs_for_generation(
  1215. input_ids,
  1216. past_key_values=past_key_values,
  1217. attention_mask=attention_mask,
  1218. image_embeds=image_embeds,
  1219. image_embeds_position_mask=image_embeds_position_mask,
  1220. inputs_embeds=inputs_embeds,
  1221. use_cache=use_cache,
  1222. cache_position=cache_position,
  1223. **model_kwargs,
  1224. )
  1225. # Kosmos2 has offset for position ids, so we need to create them correctly in PositionEmbedding layer
  1226. model_inputs.pop("position_ids", None)
  1227. return model_inputs
  1228. class Kosmos2ImageToTextProjection(nn.Module):
  1229. """The layer that transforms the image model's output to part of the text model's input (namely, image features)"""
  1230. def __init__(self, config: Kosmos2Config):
  1231. super().__init__()
  1232. self.dense = nn.Linear(config.vision_config.hidden_size, config.text_config.embed_dim)
  1233. self.latent_query = nn.Parameter(torch.randn(config.latent_query_num, config.text_config.embed_dim))
  1234. self.x_attn = KosmosTextAttention(
  1235. config.text_config,
  1236. config.text_config.embed_dim,
  1237. config.text_config.attention_heads,
  1238. dropout=config.text_config.attention_dropout,
  1239. is_decoder=False,
  1240. add_inner_attn_layernorm=False,
  1241. )
  1242. def forward(self, features):
  1243. hidden_states = self.dense(features)
  1244. # shape = [batch, latent_query_num, h_dim]
  1245. latent_query = self.latent_query.unsqueeze(0).expand(hidden_states.size(0), -1, -1)
  1246. key_value_states = torch.cat([hidden_states, latent_query], dim=1)
  1247. hidden_states, attn_weights = self.x_attn(
  1248. hidden_states=latent_query,
  1249. encoder_hidden_states=key_value_states,
  1250. past_key_values=None,
  1251. attention_mask=None,
  1252. output_attentions=None,
  1253. )
  1254. return hidden_states, attn_weights
  1255. @auto_docstring(
  1256. custom_intro="""
  1257. KOSMOS-2 Model for generating text and image features. The model consists of a vision encoder and a language model.
  1258. """
  1259. )
  1260. class Kosmos2Model(Kosmos2PreTrainedModel):
  1261. config: Kosmos2Config
  1262. main_input_name = "pixel_values"
  1263. def __init__(self, config: Kosmos2Config):
  1264. super().__init__(config)
  1265. self.text_model = Kosmos2TextModel(config.text_config)
  1266. self.vision_model = Kosmos2VisionModel(config.vision_config)
  1267. self.image_to_text_projection = Kosmos2ImageToTextProjection(config)
  1268. # Initialize weights and apply final processing
  1269. self.post_init()
  1270. def get_input_embeddings(self) -> nn.Module:
  1271. return self.text_model.model.embed_tokens
  1272. def set_input_embeddings(self, value):
  1273. self.text_model.model.embed_tokens = value
  1274. def get_image_features(
  1275. self,
  1276. pixel_values: torch.FloatTensor,
  1277. return_attentions: Optional[bool] = False,
  1278. interpolate_pos_encoding: Optional[bool] = False,
  1279. ):
  1280. """
  1281. Encodes images into continuous embeddings that can be forwarded to the language model.
  1282. Args:
  1283. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  1284. The tensors corresponding to the input images.
  1285. return_attentions (`bool`, *optional*, defaults to `False`):
  1286. Whether to return `projection_attentions` or not.
  1287. interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
  1288. Whether to interpolate positional embeddings or not.
  1289. """
  1290. vision_model_output = self.vision_model(
  1291. pixel_values=pixel_values,
  1292. interpolate_pos_encoding=interpolate_pos_encoding,
  1293. )
  1294. # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
  1295. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
  1296. # normalized features
  1297. image_embeds = nn.functional.normalize(image_embeds, dim=-1)
  1298. image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
  1299. if return_attentions:
  1300. return image_embeds, projection_attentions
  1301. return image_embeds
  1302. @can_return_tuple
  1303. @auto_docstring
  1304. def forward(
  1305. self,
  1306. pixel_values: Optional[torch.Tensor] = None,
  1307. input_ids: Optional[torch.Tensor] = None,
  1308. image_embeds_position_mask: Optional[torch.Tensor] = None,
  1309. attention_mask: Optional[torch.Tensor] = None,
  1310. head_mask: Optional[torch.Tensor] = None,
  1311. past_key_values: Optional[Cache] = None,
  1312. image_embeds: Optional[torch.Tensor] = None,
  1313. inputs_embeds: Optional[torch.Tensor] = None,
  1314. position_ids: Optional[torch.Tensor] = None,
  1315. use_cache: Optional[bool] = None,
  1316. output_attentions: Optional[bool] = None,
  1317. output_hidden_states: Optional[bool] = None,
  1318. interpolate_pos_encoding: bool = False,
  1319. return_dict: Optional[bool] = None,
  1320. **kwargs: Unpack[FlashAttentionKwargs],
  1321. ) -> Union[tuple, Kosmos2ModelOutput]:
  1322. r"""
  1323. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1324. Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
  1325. 1]`:
  1326. - 1 for places where to put the image features,
  1327. - 0 for places that are not for image features (i.e. for text tokens).
  1328. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  1329. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  1330. Examples:
  1331. ```python
  1332. >>> from PIL import Image
  1333. >>> import requests
  1334. >>> from transformers import AutoProcessor, Kosmos2Model
  1335. >>> model = Kosmos2Model.from_pretrained("microsoft/kosmos-2-patch14-224")
  1336. >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
  1337. >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
  1338. >>> image = Image.open(requests.get(url, stream=True).raw)
  1339. >>> text = (
  1340. ... "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863>"
  1341. ... "</object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911>"
  1342. ... "</object>"
  1343. ... )
  1344. >>> inputs = processor(text=text, images=image, return_tensors="pt", add_eos_token=True)
  1345. >>> last_hidden_state = model(
  1346. ... pixel_values=inputs["pixel_values"],
  1347. ... input_ids=inputs["input_ids"],
  1348. ... attention_mask=inputs["attention_mask"],
  1349. ... image_embeds_position_mask=inputs["image_embeds_position_mask"],
  1350. ... ).last_hidden_state
  1351. >>> list(last_hidden_state.shape)
  1352. [1, 91, 2048]
  1353. ```"""
  1354. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1355. output_hidden_states = (
  1356. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1357. )
  1358. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1359. vision_model_output = None
  1360. projection_attentions = None
  1361. if image_embeds is None:
  1362. if pixel_values is None:
  1363. raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
  1364. image_embeds, projection_attentions = self.get_image_features(
  1365. pixel_values, return_attentions=True, interpolate_pos_encoding=interpolate_pos_encoding
  1366. )
  1367. outputs = self.text_model(
  1368. input_ids=input_ids,
  1369. attention_mask=attention_mask,
  1370. image_embeds=image_embeds,
  1371. image_embeds_position_mask=image_embeds_position_mask,
  1372. head_mask=head_mask,
  1373. past_key_values=past_key_values,
  1374. inputs_embeds=inputs_embeds,
  1375. position_ids=position_ids,
  1376. use_cache=use_cache,
  1377. output_attentions=output_attentions,
  1378. output_hidden_states=output_hidden_states,
  1379. return_dict=True,
  1380. **kwargs,
  1381. )
  1382. return Kosmos2ModelOutput(
  1383. last_hidden_state=outputs.last_hidden_state,
  1384. past_key_values=outputs.past_key_values,
  1385. hidden_states=outputs.hidden_states,
  1386. attentions=outputs.attentions,
  1387. image_embeds=image_embeds,
  1388. projection_attentions=projection_attentions,
  1389. vision_model_output=vision_model_output,
  1390. )
  1391. @auto_docstring(
  1392. custom_intro="""
  1393. KOSMOS-2 Model for generating text and bounding boxes given an image. The model consists of a vision encoder and a
  1394. language model.
  1395. """
  1396. )
  1397. class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin):
  1398. config: Kosmos2Config
  1399. main_input_name = "pixel_values"
  1400. _tied_weights_keys = ["text_model.lm_head.weight"]
  1401. def __init__(self, config: Kosmos2Config):
  1402. super().__init__(config)
  1403. self.text_model = Kosmos2TextForCausalLM(config.text_config)
  1404. self.vision_model = Kosmos2VisionModel(config.vision_config)
  1405. self.image_to_text_projection = Kosmos2ImageToTextProjection(config)
  1406. # Initialize weights and apply final processing
  1407. self.post_init()
  1408. def get_input_embeddings(self) -> nn.Module:
  1409. return self.text_model.model.embed_tokens
  1410. def set_input_embeddings(self, value):
  1411. self.text_model.model.embed_tokens = value
  1412. def get_output_embeddings(self) -> nn.Module:
  1413. return self.text_model.get_output_embeddings()
  1414. def set_output_embeddings(self, new_embeddings):
  1415. self.text_model.set_output_embeddings(new_embeddings)
  1416. @can_return_tuple
  1417. @auto_docstring
  1418. def forward(
  1419. self,
  1420. pixel_values: Optional[torch.Tensor] = None,
  1421. input_ids: Optional[torch.Tensor] = None,
  1422. image_embeds_position_mask: Optional[torch.Tensor] = None,
  1423. attention_mask: Optional[torch.Tensor] = None,
  1424. head_mask: Optional[torch.Tensor] = None,
  1425. past_key_values: Optional[Cache] = None,
  1426. image_embeds: Optional[torch.Tensor] = None,
  1427. inputs_embeds: Optional[torch.Tensor] = None,
  1428. position_ids: Optional[torch.Tensor] = None,
  1429. labels: Optional[torch.LongTensor] = None,
  1430. use_cache: Optional[bool] = None,
  1431. output_attentions: Optional[bool] = None,
  1432. output_hidden_states: Optional[bool] = None,
  1433. **kwargs: Unpack[TransformersKwargs],
  1434. ) -> Union[tuple, Kosmos2ForConditionalGenerationModelOutput]:
  1435. r"""
  1436. image_embeds_position_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1437. Mask to indicate the location in a sequence to insert the image features . Mask values selected in `[0,
  1438. 1]`:
  1439. - 1 for places where to put the image features,
  1440. - 0 for places that are not for image features (i.e. for text tokens).
  1441. image_embeds (`torch.FloatTensor` of shape `(batch_size, latent_query_num, hidden_size)`, *optional*):
  1442. Sequence of hidden-states at the output of `Kosmos2ImageToTextProjection`.
  1443. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1444. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1445. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1446. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1447. Examples:
  1448. ```python
  1449. >>> from PIL import Image
  1450. >>> import requests
  1451. >>> from transformers import AutoProcessor, Kosmos2ForConditionalGeneration
  1452. >>> model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224")
  1453. >>> processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
  1454. >>> url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg"
  1455. >>> image = Image.open(requests.get(url, stream=True).raw)
  1456. >>> prompt = "<grounding> An image of"
  1457. >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
  1458. >>> generated_ids = model.generate(
  1459. ... pixel_values=inputs["pixel_values"],
  1460. ... input_ids=inputs["input_ids"],
  1461. ... attention_mask=inputs["attention_mask"],
  1462. ... image_embeds=None,
  1463. ... image_embeds_position_mask=inputs["image_embeds_position_mask"],
  1464. ... use_cache=True,
  1465. ... max_new_tokens=64,
  1466. ... )
  1467. >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  1468. >>> processed_text = processor.post_process_generation(generated_text, cleanup_and_extract=False)
  1469. >>> processed_text
  1470. '<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>.'
  1471. >>> caption, entities = processor.post_process_generation(generated_text)
  1472. >>> caption
  1473. 'An image of a snowman warming himself by a fire.'
  1474. >>> entities
  1475. [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
  1476. ```"""
  1477. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1478. output_hidden_states = (
  1479. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1480. )
  1481. vision_model_output = None
  1482. projection_attentions = None
  1483. if image_embeds is None:
  1484. if pixel_values is None:
  1485. raise ValueError("You have to specify either `pixel_values` or `image_embeds`.")
  1486. vision_model_output = self.vision_model(
  1487. pixel_values=pixel_values,
  1488. output_attentions=output_attentions,
  1489. output_hidden_states=output_hidden_states,
  1490. )
  1491. # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
  1492. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
  1493. # normalized features
  1494. image_embeds = nn.functional.normalize(image_embeds, dim=-1)
  1495. image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
  1496. lm_outputs = self.text_model(
  1497. input_ids=input_ids,
  1498. attention_mask=attention_mask,
  1499. image_embeds=image_embeds,
  1500. image_embeds_position_mask=image_embeds_position_mask,
  1501. head_mask=head_mask,
  1502. past_key_values=past_key_values,
  1503. inputs_embeds=inputs_embeds,
  1504. position_ids=position_ids,
  1505. labels=labels,
  1506. use_cache=use_cache,
  1507. output_attentions=output_attentions,
  1508. output_hidden_states=output_hidden_states,
  1509. return_dict=True,
  1510. **kwargs,
  1511. )
  1512. return Kosmos2ForConditionalGenerationModelOutput(
  1513. loss=lm_outputs.loss,
  1514. logits=lm_outputs.logits,
  1515. past_key_values=lm_outputs.past_key_values,
  1516. hidden_states=lm_outputs.hidden_states,
  1517. attentions=lm_outputs.attentions,
  1518. image_embeds=image_embeds,
  1519. projection_attentions=projection_attentions,
  1520. vision_model_output=vision_model_output,
  1521. )
  1522. @torch.no_grad()
  1523. def generate(
  1524. self,
  1525. pixel_values: Optional[torch.Tensor] = None,
  1526. image_embeds_position_mask: Optional[torch.Tensor] = None,
  1527. input_ids: Optional[torch.Tensor] = None,
  1528. attention_mask: Optional[torch.Tensor] = None,
  1529. image_embeds: Optional[torch.Tensor] = None,
  1530. inputs_embeds: Optional[torch.Tensor] = None,
  1531. **kwargs,
  1532. ):
  1533. # in order to allow `inputs` argument (as in `GenerationMixin`)
  1534. inputs = kwargs.pop("inputs", None)
  1535. if pixel_values is not None and inputs is not None:
  1536. raise ValueError(
  1537. f"`inputs`: {inputs} were passed alongside `pixel_values` which is not allowed."
  1538. f"Make sure to either pass `inputs` or pixel_values=..."
  1539. )
  1540. if pixel_values is None and inputs is not None:
  1541. pixel_values = inputs
  1542. if image_embeds is None:
  1543. vision_model_output = self.vision_model(pixel_values)
  1544. # The whole `last_hidden_state` through `post_layernorm` instead of just `pooled_output`.
  1545. image_embeds = self.vision_model.model.post_layernorm(vision_model_output[0])
  1546. # normalized features
  1547. image_embeds = nn.functional.normalize(image_embeds, dim=-1)
  1548. image_embeds, projection_attentions = self.image_to_text_projection(image_embeds)
  1549. output = self.text_model.generate(
  1550. input_ids=input_ids,
  1551. attention_mask=attention_mask,
  1552. image_embeds=image_embeds,
  1553. image_embeds_position_mask=image_embeds_position_mask,
  1554. inputs_embeds=inputs_embeds,
  1555. **kwargs,
  1556. )
  1557. return output
  1558. __all__ = ["Kosmos2ForConditionalGeneration", "Kosmos2Model", "Kosmos2PreTrainedModel"]