modeling_janus.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/janus/modular_janus.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_janus.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # coding=utf-8
  8. # Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
  9. #
  10. # Licensed under the Apache License, Version 2.0 (the "License");
  11. # you may not use this file except in compliance with the License.
  12. # You may obtain a copy of the License at
  13. #
  14. # http://www.apache.org/licenses/LICENSE-2.0
  15. #
  16. # Unless required by applicable law or agreed to in writing, software
  17. # distributed under the License is distributed on an "AS IS" BASIS,
  18. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  19. # See the License for the specific language governing permissions and
  20. # limitations under the License.
  21. import copy
  22. from dataclasses import dataclass
  23. from typing import Callable, Optional, Union
  24. import torch
  25. import torch.nn.functional as F
  26. from torch import nn
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache
  29. from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
  30. from ...generation.utils import GenerateDecoderOnlyOutput
  31. from ...modeling_layers import GradientCheckpointingLayer
  32. from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ModelOutput
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
  36. from ...utils.generic import check_model_inputs
  37. from ..auto import AutoModel
  38. from .configuration_janus import JanusConfig, JanusVisionConfig, JanusVQVAEConfig
# Module-level logger from the library's own logging utilities (imported from `...utils`), keyed by module name.
logger = logging.get_logger(__name__)
  40. @auto_docstring
  41. class JanusPreTrainedModel(PreTrainedModel):
  42. config: JanusConfig
  43. base_model_prefix = "model"
  44. supports_gradient_checkpointing = True
  45. _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
  46. _skip_keys_device_placement = ["past_key_values", "causal_mask"]
  47. _supports_flash_attn = True
  48. _supports_sdpa = True
  49. _can_compile_fullgraph = True
  50. _supports_param_buffer_assignment = False
@dataclass
@auto_docstring(
    custom_intro="""
Base class for Janus VQ-VAE mode model outputs.
"""
)
class JanusVQVAEOutput(ModelOutput):
    r"""
    decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
        Reconstructed pixel values after encoding and decoding the input.
    embedding_loss (`torch.FloatTensor`):
        Embedding loss of the vector-quantization step.
    """

    # NOTE(review): field order presumably defines the tuple order exposed by `ModelOutput`; do not reorder.
    decoded_pixel_values: Optional[torch.FloatTensor] = None
    embedding_loss: Optional[torch.FloatTensor] = None
@dataclass
@auto_docstring(
    custom_intro="""
Base class for Janus model's outputs that may also contain a past key/values (to speed up sequential decoding).
"""
)
class JanusBaseModelOutputWithPast(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.

        If `past_key_values` is used, only the last hidden-state of the sequences of shape `(batch_size, 1,
        hidden_size)` is output.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
        `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
        input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings), each of shape
        `(batch_size, num_images, sequence_length, hidden_size)`.

        Image hidden states of the model produced by the vision encoder, and optionally by the perceiver.
    """

    # NOTE(review): field order presumably defines the tuple order exposed by `ModelOutput`; do not reorder.
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    # `hidden_states` / `attentions` carry no entry above; their docs are presumably filled in by `auto_docstring`.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Base class for Janus causal language model (or autoregressive) outputs.
"""
)
class JanusCausalLMOutputWithPast(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of the image embeddings), each of shape
        `(batch_size, num_images, sequence_length, hidden_size)`.

        Image hidden states of the model produced by the vision encoder, and optionally by the perceiver.
    """

    # NOTE(review): field order presumably defines the tuple order exposed by `ModelOutput`; do not reorder.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
    image_hidden_states: Optional[tuple[torch.FloatTensor]] = None
  120. class JanusVisionEmbeddings(nn.Module):
  121. def __init__(self, config: JanusVisionConfig):
  122. super().__init__()
  123. self.config = config
  124. self.embed_dim = config.hidden_size
  125. self.image_size = config.image_size
  126. self.patch_size = config.patch_size
  127. self.patch_embedding = nn.Conv2d(
  128. in_channels=config.num_channels,
  129. out_channels=self.embed_dim,
  130. kernel_size=self.patch_size,
  131. stride=self.patch_size,
  132. padding="valid",
  133. )
  134. self.num_patches = (self.image_size // self.patch_size) ** 2
  135. self.num_positions = self.num_patches
  136. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  137. self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
  138. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  139. """
  140. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  141. images. This method is also adapted to support torch.jit tracing and no class embeddings.
  142. Adapted from:
  143. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  144. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  145. """
  146. num_patches = embeddings.shape[1]
  147. num_positions = self.position_embedding.weight.shape[0]
  148. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  149. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  150. return self.position_embedding(self.position_ids)
  151. patch_pos_embed = self.position_embedding.weight.unsqueeze(0)
  152. dim = embeddings.shape[-1]
  153. new_height = height // self.patch_size
  154. new_width = width // self.patch_size
  155. sqrt_num_positions = torch_int(num_positions**0.5)
  156. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  157. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  158. patch_pos_embed = nn.functional.interpolate(
  159. patch_pos_embed,
  160. size=(new_height, new_width),
  161. mode="bicubic",
  162. align_corners=False,
  163. )
  164. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  165. return patch_pos_embed
  166. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  167. _, _, height, width = pixel_values.shape
  168. target_dtype = self.patch_embedding.weight.dtype
  169. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  170. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  171. if interpolate_pos_encoding:
  172. pos_embeds = self.interpolate_pos_encoding(embeddings, height, width)
  173. else:
  174. pos_embeds = self.position_embedding(self.position_ids)
  175. embeddings = embeddings + pos_embeds
  176. return embeddings
  177. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  178. """
  179. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  180. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  181. """
  182. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  183. if n_rep == 1:
  184. return hidden_states
  185. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  186. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  187. def eager_attention_forward(
  188. module: nn.Module,
  189. query: torch.Tensor,
  190. key: torch.Tensor,
  191. value: torch.Tensor,
  192. attention_mask: Optional[torch.Tensor],
  193. scaling: float,
  194. dropout: float = 0.0,
  195. **kwargs: Unpack[TransformersKwargs],
  196. ):
  197. key_states = repeat_kv(key, module.num_key_value_groups)
  198. value_states = repeat_kv(value, module.num_key_value_groups)
  199. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  200. if attention_mask is not None:
  201. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  202. attn_weights = attn_weights + causal_mask
  203. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  204. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  205. attn_output = torch.matmul(attn_weights, value_states)
  206. attn_output = attn_output.transpose(1, 2).contiguous()
  207. return attn_output, attn_weights
  208. class JanusVisionAttention(nn.Module):
  209. """Attention Class for Janus Vision Encoder"""
  210. def __init__(self, config: JanusVisionConfig):
  211. super().__init__()
  212. self.config = config
  213. self.embed_dim = config.hidden_size
  214. self.num_heads = config.num_attention_heads
  215. self.head_dim = self.embed_dim // self.num_heads
  216. if self.head_dim * self.num_heads != self.embed_dim:
  217. raise ValueError(
  218. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  219. f" {self.num_heads})."
  220. )
  221. self.scale = self.head_dim**-0.5
  222. self.attention_dropout = config.attention_dropout
  223. proj_dropout = config.projection_dropout
  224. qk_norm = config.use_qk_norm
  225. self.is_causal = False
  226. # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
  227. self.num_key_value_groups = 1
  228. self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  229. self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  230. self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  231. self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
  232. self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()
  233. self.q_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
  234. self.k_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
  235. def forward(
  236. self,
  237. hidden_states: torch.Tensor,
  238. attention_mask: Optional[torch.Tensor] = None,
  239. **kwargs: Unpack[TransformersKwargs],
  240. ):
  241. batch_size, seq_len, _ = hidden_states.size()
  242. query_states = self.q_proj(hidden_states)
  243. key_states = self.k_proj(hidden_states)
  244. value_states = self.v_proj(hidden_states)
  245. query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
  246. query_states = self.q_norm(query_states)
  247. key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
  248. key_states = self.k_norm(key_states)
  249. query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  250. key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  251. value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  252. attention_interface: Callable = eager_attention_forward
  253. if self.config._attn_implementation != "eager":
  254. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  255. attn_output, attn_weights = attention_interface(
  256. self,
  257. query_states,
  258. key_states,
  259. value_states,
  260. attention_mask,
  261. dropout=0.0 if not self.training else self.attention_dropout,
  262. scaling=self.scale,
  263. is_causal=self.is_causal,
  264. **kwargs,
  265. )
  266. attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
  267. output = self.projection_layer(attn_output)
  268. output = self.projection_dropout(output)
  269. return output, attn_weights
  270. class JanusVisionMLP(nn.Module):
  271. def __init__(self, config: JanusVisionConfig):
  272. super().__init__()
  273. self.config = config
  274. self.intermediate_size = int(config.hidden_size * config.mlp_ratio)
  275. self.activation_fn = ACT2FN[config.hidden_act] # Gelu act
  276. self.fc1 = nn.Linear(config.hidden_size, self.intermediate_size)
  277. self.fc2 = nn.Linear(self.intermediate_size, config.hidden_size)
  278. self.dropout1 = nn.Dropout(config.hidden_dropout_rate)
  279. self.dropout2 = nn.Dropout(config.hidden_dropout_rate)
  280. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  281. hidden_states = self.fc1(hidden_states)
  282. hidden_states = self.activation_fn(hidden_states)
  283. hidden_states = self.dropout1(hidden_states)
  284. hidden_states = self.fc2(hidden_states)
  285. hidden_states = self.dropout2(hidden_states)
  286. return hidden_states
  287. class JanusVisionEncoderLayer(GradientCheckpointingLayer):
  288. def __init__(self, config: JanusVisionConfig):
  289. super().__init__()
  290. self.embed_dim = config.hidden_size
  291. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  292. self.self_attn = JanusVisionAttention(config)
  293. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  294. self.mlp = JanusVisionMLP(config)
  295. self.config = config
  296. @auto_docstring
  297. def forward(
  298. self,
  299. hidden_states: torch.Tensor,
  300. attention_mask: torch.Tensor,
  301. **kwargs: Unpack[TransformersKwargs],
  302. ) -> torch.FloatTensor:
  303. residual = hidden_states
  304. hidden_states = self.layer_norm1(hidden_states)
  305. hidden_states, _ = self.self_attn(
  306. hidden_states=hidden_states,
  307. attention_mask=attention_mask,
  308. **kwargs,
  309. )
  310. hidden_states = residual + hidden_states
  311. residual = hidden_states
  312. hidden_states = self.layer_norm2(hidden_states)
  313. hidden_states = self.mlp(hidden_states)
  314. hidden_states = residual + hidden_states
  315. return hidden_states
  316. class JanusVisionEncoder(nn.Module):
  317. """
  318. Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
  319. [`JanusVisionEncoderLayer`].
  320. Args:
  321. config: JanusVisionConfig
  322. """
  323. def __init__(self, config: JanusVisionConfig):
  324. super().__init__()
  325. self.config = config
  326. self.layers = nn.ModuleList([JanusVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
  327. self.gradient_checkpointing = False
  328. # Ignore copy
  329. @auto_docstring
  330. def forward(
  331. self,
  332. inputs_embeds,
  333. attention_mask: Optional[torch.Tensor] = None,
  334. **kwargs: Unpack[TransformersKwargs],
  335. ) -> BaseModelOutput:
  336. hidden_states = inputs_embeds
  337. for encoder_layer in self.layers:
  338. hidden_states = encoder_layer(
  339. hidden_states,
  340. attention_mask,
  341. **kwargs,
  342. )
  343. return BaseModelOutput(last_hidden_state=hidden_states)
  344. class JanusAttention(nn.Module):
  345. """Multi-headed attention from 'Attention Is All You Need' paper"""
  346. def __init__(self, config):
  347. super().__init__()
  348. self.config = config
  349. self.embed_dim = config.hidden_size
  350. self.num_heads = config.num_attention_heads
  351. self.head_dim = self.embed_dim // self.num_heads
  352. if self.head_dim * self.num_heads != self.embed_dim:
  353. raise ValueError(
  354. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  355. f" {self.num_heads})."
  356. )
  357. self.scale = self.head_dim**-0.5
  358. self.is_causal = False
  359. self.attention_dropout = config.attention_dropout
  360. # small tweak here compared to CLIP, no bias here
  361. self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
  362. if config.qkv_bias:
  363. q_bias = nn.Parameter(torch.zeros(self.embed_dim))
  364. v_bias = nn.Parameter(torch.zeros(self.embed_dim))
  365. else:
  366. q_bias = None
  367. v_bias = None
  368. if q_bias is not None:
  369. qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
  370. self.qkv.bias = nn.Parameter(qkv_bias)
  371. self.projection = nn.Linear(self.embed_dim, self.embed_dim)
  372. def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
  373. return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
  374. def forward(
  375. self,
  376. hidden_states: torch.Tensor,
  377. head_mask: Optional[torch.Tensor] = None,
  378. **kwargs,
  379. ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  380. """Input shape: Batch x Time x Channel"""
  381. bsz, tgt_len, embed_dim = hidden_states.size()
  382. mixed_qkv = self.qkv(hidden_states)
  383. mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
  384. 2, 0, 3, 1, 4
  385. )
  386. query_states, key_states, value_states = mixed_qkv[0], mixed_qkv[1], mixed_qkv[2]
  387. attention_interface: Callable = eager_attention_forward
  388. if self.config._attn_implementation != "eager":
  389. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  390. attn_output, attn_weights = attention_interface(
  391. self,
  392. query_states,
  393. key_states,
  394. value_states,
  395. attention_mask=None,
  396. dropout=0.0 if not self.training else self.attention_dropout,
  397. scaling=self.scale,
  398. **kwargs,
  399. )
  400. attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
  401. attn_output = self.projection(attn_output)
  402. return attn_output, attn_weights
  403. class JanusMLP(nn.Module):
  404. def __init__(self, config):
  405. super().__init__()
  406. self.config = config
  407. self.activation_fn = ACT2FN[config.hidden_act]
  408. self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
  409. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  410. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  411. hidden_states = self.fc1(hidden_states)
  412. hidden_states = self.activation_fn(hidden_states)
  413. hidden_states = self.fc2(hidden_states)
  414. return hidden_states
  415. class JanusEncoderLayer(GradientCheckpointingLayer):
  416. def __init__(self, config: JanusConfig):
  417. super().__init__()
  418. self.embed_dim = config.hidden_size
  419. self.self_attn = JanusAttention(config)
  420. self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  421. self.mlp = JanusMLP(config)
  422. self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
  423. @auto_docstring
  424. def forward(
  425. self,
  426. hidden_states: torch.Tensor,
  427. attention_mask: torch.Tensor,
  428. **kwargs: Unpack[TransformersKwargs],
  429. ) -> torch.FloatTensor:
  430. residual = hidden_states
  431. hidden_states = self.layer_norm1(hidden_states)
  432. hidden_states, _ = self.self_attn(
  433. hidden_states=hidden_states,
  434. head_mask=attention_mask,
  435. **kwargs,
  436. )
  437. hidden_states = hidden_states + residual
  438. residual = hidden_states
  439. hidden_states = self.layer_norm2(hidden_states)
  440. hidden_states = self.mlp(hidden_states)
  441. hidden_states = hidden_states + residual
  442. return hidden_states
  443. @auto_docstring
  444. class JanusVisionModel(JanusPreTrainedModel):
  445. main_input_name = "pixel_values"
  446. config: JanusVisionConfig
  447. _can_record_outputs = {
  448. "hidden_states": JanusEncoderLayer,
  449. "attentions": JanusAttention,
  450. }
  451. def __init__(self, config: JanusVisionConfig):
  452. super().__init__(config)
  453. self.config = config
  454. embed_dim = config.hidden_size
  455. self.embeddings = JanusVisionEmbeddings(config)
  456. self.encoder = JanusVisionEncoder(config)
  457. self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
  458. self.post_init()
  459. @check_model_inputs(tie_last_hidden_states=False)
  460. @auto_docstring
  461. def forward(
  462. self,
  463. pixel_values: Optional[torch.FloatTensor] = None,
  464. interpolate_pos_encoding: bool = False,
  465. **kwargs: Unpack[TransformersKwargs],
  466. ) -> Union[tuple, BaseModelOutputWithPooling]:
  467. if pixel_values is None:
  468. raise ValueError("You have to specify pixel_values")
  469. hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
  470. encoder_outputs: BaseModelOutput = self.encoder(
  471. inputs_embeds=hidden_states,
  472. **kwargs,
  473. )
  474. last_hidden_state = encoder_outputs.last_hidden_state
  475. last_hidden_state = self.post_layernorm(last_hidden_state)
  476. pooled_output = last_hidden_state[:, 0, :]
  477. pooled_output = self.post_layernorm(pooled_output)
  478. return BaseModelOutputWithPooling(
  479. last_hidden_state=last_hidden_state,
  480. pooler_output=pooled_output,
  481. )
  482. def get_input_embeddings(self):
  483. return self.embeddings
  484. class JanusVisionAlignerMLP(nn.Module):
  485. def __init__(self, config: JanusVisionConfig):
  486. super().__init__()
  487. self.fc1 = nn.Linear(config.hidden_size, config.projection_dim)
  488. self.hidden_layers = nn.ModuleList(
  489. [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.depth)]
  490. )
  491. self.activation_fn = ACT2FN[config.hidden_act]
  492. def forward(self, hidden_states):
  493. hidden_states = self.fc1(hidden_states)
  494. for layer in self.hidden_layers:
  495. hidden_states = self.activation_fn(hidden_states)
  496. hidden_states = layer(hidden_states)
  497. return hidden_states
  498. class JanusVQVAEVectorQuantizer(nn.Module):
  499. """
  500. A module for vector quantization using learned embedding vectors.
  501. This module implements the quantization process similar to te one described in
  502. the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
  503. input vectors into discrete codebook vectors, which are learned during training.
  504. Current implementation improves over previous ones by avoiding costly matrix multiplications
  505. and allowing for post-hoc remapping of indices.
  506. """
  507. def __init__(self, config: JanusVQVAEConfig):
  508. super().__init__()
  509. self.num_embeddings = config.num_embeddings
  510. self.embedding_dim = config.embed_dim
  511. self.beta = getattr(config, "beta", 0.25)
  512. self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
  513. self.quant_state_dims = [config.num_patches] * 2
  514. def forward(self, hidden_state: torch.Tensor):
  515. hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
  516. hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
  517. # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
  518. distances = (
  519. torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
  520. + torch.sum(self.embedding.weight**2, dim=1)
  521. - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, self.embedding.weight.transpose(0, 1))
  522. )
  523. min_encoding_indices = torch.argmin(distances, dim=1)
  524. hidden_state_quant = self.embedding(min_encoding_indices).view(hidden_state.shape)
  525. # compute loss for embedding
  526. loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
  527. (hidden_state_quant - hidden_state.detach()) ** 2
  528. )
  529. # preserve gradients
  530. hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
  531. # reshape back to match original input shape
  532. hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
  533. return hidden_state_quant, loss, min_encoding_indices
  534. def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
  535. batch_size = image_tokens.shape[0]
  536. emb_dim: int = self.embedding.weight.shape[-1]
  537. # get quantized latent vectors
  538. hidden_state_quant = self.embedding(image_tokens)
  539. # l2 normalization on the last dimension
  540. hidden_state_quant = F.normalize(hidden_state_quant, p=2, dim=-1)
  541. # reshape back to match original input shape
  542. hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim))
  543. hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
  544. return hidden_state_quant
  545. class JanusVQVAEResnetBlock(nn.Module):
  546. def __init__(
  547. self,
  548. config,
  549. in_channels,
  550. out_channels=None,
  551. conv_shortcut=False,
  552. ):
  553. super().__init__()
  554. self.in_channels = in_channels
  555. self.out_channels = in_channels if out_channels is None else out_channels
  556. self.use_conv_shortcut = conv_shortcut
  557. self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
  558. self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
  559. self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
  560. self.dropout = torch.nn.Dropout(config.dropout)
  561. self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
  562. if self.in_channels != self.out_channels:
  563. if self.use_conv_shortcut:
  564. self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
  565. else:
  566. self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
  567. def forward(self, hidden_states):
  568. residual = hidden_states
  569. hidden_states = self.norm1(hidden_states)
  570. hidden_states *= torch.sigmoid(hidden_states)
  571. hidden_states = self.conv1(hidden_states)
  572. hidden_states = self.norm2(hidden_states)
  573. hidden_states *= torch.sigmoid(hidden_states)
  574. hidden_states = self.dropout(hidden_states)
  575. hidden_states = self.conv2(hidden_states)
  576. if self.in_channels != self.out_channels:
  577. if self.use_conv_shortcut:
  578. residual = self.conv_shortcut(residual)
  579. else:
  580. residual = self.nin_shortcut(residual)
  581. return residual + hidden_states
  582. class JanusVQVAEAttnBlock(nn.Module):
  583. def __init__(self, in_channels):
  584. super().__init__()
  585. self.in_channels = in_channels
  586. self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
  587. self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  588. self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  589. self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  590. self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
  591. def forward(self, hidden_states):
  592. residual = hidden_states
  593. hidden_states = self.norm(hidden_states)
  594. query_states = self.q(hidden_states)
  595. key_states = self.k(hidden_states)
  596. value_states = self.v(hidden_states)
  597. # compute attention
  598. batch_size, channels, height, width = query_states.shape
  599. query_states = query_states.reshape(batch_size, channels, height * width).permute(0, 2, 1)
  600. key_states = key_states.reshape(batch_size, channels, height * width)
  601. attn_weights = torch.bmm(query_states, key_states)
  602. attn_weights = attn_weights * (int(channels) ** (-0.5))
  603. attn_weights = F.softmax(attn_weights, dim=2)
  604. # attend to values
  605. value_states = value_states.reshape(batch_size, channels, height * width)
  606. attn_weights = attn_weights.permute(0, 2, 1)
  607. attn_output = torch.bmm(value_states, attn_weights).reshape(batch_size, channels, height, width)
  608. attn_output = self.proj_out(attn_output)
  609. return residual + attn_output
  610. class JanusVQVAEConvDownsample(nn.Module):
  611. def __init__(self, in_channels):
  612. super().__init__()
  613. self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
  614. def forward(self, hidden_states):
  615. # no asymmetric padding in torch conv, must do it ourselves
  616. hidden_states = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
  617. hidden_states = self.conv(hidden_states)
  618. return hidden_states
  619. class JanusVQVAEConvUpsample(nn.Module):
  620. def __init__(self, in_channels):
  621. super().__init__()
  622. self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
  623. def forward(self, hidden_states):
  624. hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  625. hidden_states = self.conv(hidden_states)
  626. return hidden_states
  627. class JanusVQVAEMidBlock(nn.Module):
  628. def __init__(self, config: JanusVQVAEConfig, channels: int):
  629. super().__init__()
  630. self.block_1 = JanusVQVAEResnetBlock(
  631. config=config,
  632. in_channels=channels,
  633. out_channels=channels,
  634. )
  635. self.attn_1 = JanusVQVAEAttnBlock(channels)
  636. self.block_2 = JanusVQVAEResnetBlock(
  637. config=config,
  638. in_channels=channels,
  639. out_channels=channels,
  640. )
  641. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  642. hidden_states = self.block_1(hidden_states)
  643. hidden_states = self.attn_1(hidden_states)
  644. hidden_states = self.block_2(hidden_states)
  645. return hidden_states
  646. class JanusVQVAEEncoder(nn.Module):
  647. def __init__(self, config):
  648. super().__init__()
  649. self.num_resolutions = len(config.channel_multiplier)
  650. self.num_res_blocks = config.num_res_blocks
  651. base_channels = config.base_channels
  652. in_channels = config.in_channels
  653. double_latent = config.double_latent
  654. latent_channels = config.latent_channels
  655. channel_multiplier = config.channel_multiplier
  656. self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
  657. in_channel_multiplier = (1,) + tuple(channel_multiplier)
  658. self.in_channel_multiplier = in_channel_multiplier
  659. self.down = nn.ModuleList()
  660. for i_level in range(self.num_resolutions):
  661. block = nn.ModuleList()
  662. attn = nn.ModuleList()
  663. block_in = base_channels * in_channel_multiplier[i_level]
  664. block_out = base_channels * channel_multiplier[i_level]
  665. for i_block in range(self.num_res_blocks):
  666. block.append(
  667. JanusVQVAEResnetBlock(
  668. config=config,
  669. in_channels=block_in,
  670. out_channels=block_out,
  671. )
  672. )
  673. block_in = block_out
  674. if i_level == self.num_resolutions - 1:
  675. attn.append(JanusVQVAEAttnBlock(block_in))
  676. down = nn.Module()
  677. down.block = block
  678. down.attn = attn
  679. if i_level != self.num_resolutions - 1:
  680. down.downsample = JanusVQVAEConvDownsample(block_in)
  681. self.down.append(down)
  682. self.mid = JanusVQVAEMidBlock(config, block_in)
  683. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  684. self.conv_out = torch.nn.Conv2d(
  685. block_in,
  686. 2 * latent_channels if double_latent else latent_channels,
  687. kernel_size=3,
  688. stride=1,
  689. padding=1,
  690. )
  691. def forward(self, pixel_values: torch.LongTensor):
  692. # downsampling
  693. hidden_states = [self.conv_in(pixel_values)]
  694. for i_level in range(self.num_resolutions):
  695. for i_block in range(self.num_res_blocks):
  696. hidden_state = self.down[i_level].block[i_block](
  697. hidden_states[-1],
  698. )
  699. if len(self.down[i_level].attn) > 0:
  700. hidden_state = self.down[i_level].attn[i_block](hidden_state)
  701. hidden_states.append(hidden_state)
  702. if i_level != self.num_resolutions - 1:
  703. hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
  704. # middle
  705. last_hidden_state = hidden_states[-1]
  706. last_hidden_state = self.mid(last_hidden_state)
  707. # end
  708. last_hidden_state = self.norm_out(last_hidden_state)
  709. last_hidden_state *= torch.sigmoid(last_hidden_state)
  710. last_hidden_state = self.conv_out(last_hidden_state)
  711. return last_hidden_state
  712. class JanusVQVAEDecoder(nn.Module):
  713. def __init__(self, config):
  714. super().__init__()
  715. self.num_resolutions = len(config.channel_multiplier)
  716. self.num_res_blocks = config.num_res_blocks
  717. base_channels = config.base_channels
  718. latent_channels = config.latent_channels
  719. out_channels = config.out_channels
  720. # compute in_ch_mult, block_in and curr_res at lowest res
  721. block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1]
  722. # z to block_in
  723. self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1)
  724. # middle
  725. self.mid = JanusVQVAEMidBlock(config, block_in)
  726. # upsampling
  727. self.up = nn.ModuleList()
  728. for i_level in reversed(range(self.num_resolutions)):
  729. block = nn.ModuleList()
  730. attn = nn.ModuleList()
  731. block_out = base_channels * config.channel_multiplier[i_level]
  732. for i_block in range(self.num_res_blocks + 1):
  733. block.append(
  734. JanusVQVAEResnetBlock(
  735. config=config,
  736. in_channels=block_in,
  737. out_channels=block_out,
  738. )
  739. )
  740. block_in = block_out
  741. if i_level == self.num_resolutions - 1:
  742. attn.append(JanusVQVAEAttnBlock(block_in))
  743. up = nn.Module()
  744. up.block = block
  745. up.attn = attn
  746. if i_level != 0:
  747. up.upsample = JanusVQVAEConvUpsample(block_in)
  748. self.up.append(up)
  749. # end
  750. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  751. self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
  752. def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor:
  753. hidden_state = self.conv_in(hidden_state)
  754. # middle
  755. hidden_state = self.mid(hidden_state)
  756. # upsampling
  757. for i_level in range(self.num_resolutions):
  758. for i_block in range(self.num_res_blocks + 1):
  759. hidden_state = self.up[i_level].block[i_block](hidden_state)
  760. if len(self.up[i_level].attn) > 0:
  761. hidden_state = self.up[i_level].attn[i_block](hidden_state)
  762. if i_level != self.num_resolutions - 1:
  763. hidden_state = self.up[i_level].upsample(hidden_state)
  764. hidden_state = self.norm_out(hidden_state)
  765. hidden_state *= torch.sigmoid(hidden_state)
  766. hidden_state = self.conv_out(hidden_state)
  767. return hidden_state
  768. @auto_docstring(
  769. custom_intro="""
  770. The VQ-VAE model used in Janus for encoding/decoding images into discrete tokens.
  771. This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
  772. [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
  773. Taigman](https://huggingface.co/papers/2203.13131).
  774. """
  775. )
  776. class JanusVQVAE(JanusPreTrainedModel):
  777. config: JanusVQVAEConfig
  778. _no_split_modules = [
  779. "JanusVQVAEAttnBlock",
  780. "JanusVQVAEResnetBlock",
  781. "JanusVQVAEVectorQuantizer",
  782. ]
  783. main_input_name = "pixel_values"
  784. def __init__(self, config: JanusVQVAEConfig):
  785. super().__init__(config)
  786. self.encoder = JanusVQVAEEncoder(config)
  787. self.quantize = JanusVQVAEVectorQuantizer(config)
  788. self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
  789. self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
  790. self.eval() # Janus's VQ model is frozen
  791. self.decoder = JanusVQVAEDecoder(config)
  792. self.gradient_checkpointing = False
  793. # Initialize the VQVAE model.
  794. self.post_init()
  795. def encode(self, pixel_values: torch.LongTensor):
  796. hidden_states = self.encoder(pixel_values)
  797. hidden_states = self.quant_conv(hidden_states)
  798. quant, emb_loss, indices = self.quantize(hidden_states)
  799. return quant, emb_loss, indices
  800. def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
  801. """
  802. Decodes quantized token IDs into pixel values.
  803. Args:
  804. image_tokens (torch.LongTensor): Batch of token IDs.
  805. Returns:
  806. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  807. Pixel values decoded from the token IDs.
  808. """
  809. if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]:
  810. raise ValueError(
  811. f"Expected `image_tokens` to have shape `(batch_size, {self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, "
  812. f"but got shape `{image_tokens.shape}`."
  813. )
  814. codebook_entry = self.quantize.get_codebook_entry(image_tokens)
  815. hidden_states = self.post_quant_conv(codebook_entry)
  816. pixel_values = self.decoder(hidden_states)
  817. return pixel_values
  818. @can_return_tuple
  819. @auto_docstring
  820. def forward(
  821. self,
  822. pixel_values: torch.FloatTensor,
  823. ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
  824. batch_size = pixel_values.shape[0]
  825. quant, embedding_loss, indices = self.encode(pixel_values)
  826. decoded_pixel_values = self.decode(indices.view(batch_size, -1))
  827. return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
  828. class JanusVQVAEAlignerMLP(nn.Module):
  829. def __init__(self, config: JanusVQVAEConfig):
  830. super().__init__()
  831. self.fc1 = nn.Linear(config.embed_dim, config.projection_dim)
  832. self.hidden_layers = nn.ModuleList(
  833. [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.num_hidden_layers)]
  834. )
  835. self.activation_fn = ACT2FN[config.hidden_act]
  836. def forward(self, hidden_states):
  837. hidden_states = self.fc1(hidden_states)
  838. for layer in self.hidden_layers:
  839. hidden_states = self.activation_fn(hidden_states)
  840. hidden_states = layer(hidden_states)
  841. return hidden_states
  842. class JanusVQVAEHead(nn.Module):
  843. """Head used for sampling tokens in image generation, replacing the usual lm head."""
  844. def __init__(self, config: JanusVQVAEConfig):
  845. super().__init__()
  846. self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim)
  847. self.activation_fn = ACT2FN[config.hidden_act]
  848. self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings)
  849. def forward(self, hidden_states: torch.Tensor) -> torch.tensor:
  850. hidden_states = self.proj_out(hidden_states)
  851. hidden_states = self.activation_fn(hidden_states)
  852. hidden_states = self.vision_head(hidden_states)
  853. return hidden_states
  854. @auto_docstring(
  855. custom_intro="""
  856. The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.
  857. """
  858. )
  859. class JanusModel(JanusPreTrainedModel):
  860. def __init__(self, config: JanusConfig):
  861. super().__init__(config)
  862. self.config = config
  863. # This is necessary for backward compatibility, see SiglipModel initialization
  864. self.vision_model = JanusVisionModel._from_config(config.vision_config)
  865. self.aligner = JanusVisionAlignerMLP(self.vision_model.config)
  866. self.vqmodel = JanusVQVAE._from_config(config.vq_config)
  867. # Below generation_* modules are used for Image generation.
  868. # Embeddings used for image generation, instead of Janus vision embeddings.
  869. self.generation_embeddings = nn.Embedding(self.vqmodel.config.num_embeddings, self.vqmodel.config.embed_dim)
  870. self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
  871. self.generation_head = JanusVQVAEHead(self.vqmodel.config)
  872. self.language_model = AutoModel.from_config(config=config.text_config)
  873. self.gradient_checkpointing = False
  874. # Initialize weights and apply final processing.
  875. self.post_init()
  876. def get_input_embeddings(self):
  877. return self.language_model.get_input_embeddings()
  878. def set_input_embeddings(self, value):
  879. self.language_model.set_input_embeddings(value)
  880. def get_image_features(self, pixel_values):
  881. image_embeds = self.vision_model(pixel_values)
  882. image_embeds = self.aligner(image_embeds.last_hidden_state)
  883. return image_embeds
  884. def get_placeholder_mask(
  885. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  886. ):
  887. """
  888. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  889. equal to the length of multimodal features. If the lengths are different, an error is raised.
  890. """
  891. if input_ids is None:
  892. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  893. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  894. )
  895. special_image_mask = special_image_mask.all(-1)
  896. else:
  897. special_image_mask = input_ids == self.config.image_token_id
  898. n_image_tokens = special_image_mask.sum()
  899. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  900. if inputs_embeds[special_image_mask].numel() != image_features.numel():
  901. n_image_features = image_features.shape[0] * image_features.shape[1]
  902. raise ValueError(
  903. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  904. )
  905. return special_image_mask
  906. @can_return_tuple
  907. @auto_docstring
  908. def forward(
  909. self,
  910. input_ids: Optional[torch.LongTensor] = None,
  911. pixel_values: Optional[torch.FloatTensor] = None,
  912. attention_mask: Optional[torch.Tensor] = None,
  913. position_ids: Optional[torch.LongTensor] = None,
  914. past_key_values: Optional[Cache] = None,
  915. cache_position: Optional[torch.LongTensor] = None,
  916. inputs_embeds: Optional[torch.FloatTensor] = None,
  917. use_cache: Optional[bool] = None,
  918. logits_to_keep: Union[int, torch.Tensor] = 0,
  919. **kwargs,
  920. ):
  921. if (input_ids is None) ^ (inputs_embeds is not None):
  922. raise ValueError(
  923. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  924. )
  925. if inputs_embeds is None:
  926. inputs_embeds = self.get_input_embeddings()(input_ids)
  927. if pixel_values is not None:
  928. image_embeds = self.get_image_features(pixel_values)
  929. image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
  930. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  931. image_attention_mask = self.get_placeholder_mask(
  932. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  933. )
  934. inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
  935. lm_output = self.language_model(
  936. inputs_embeds=inputs_embeds,
  937. attention_mask=attention_mask,
  938. position_ids=position_ids,
  939. past_key_values=past_key_values,
  940. use_cache=use_cache,
  941. cache_position=cache_position,
  942. logits_to_keep=logits_to_keep,
  943. **kwargs,
  944. )
  945. return JanusBaseModelOutputWithPast(
  946. last_hidden_state=lm_output.last_hidden_state,
  947. past_key_values=lm_output.past_key_values,
  948. hidden_states=lm_output.hidden_states,
  949. attentions=lm_output.attentions,
  950. image_hidden_states=image_embeds if pixel_values is not None else None,
  951. )
  952. class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
  953. _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
  954. _can_compile_fullgraph = True
  955. def __init__(self, config: JanusConfig):
  956. super().__init__(config)
  957. self.config = config
  958. self.model = JanusModel(config)
  959. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  960. # Initialize weights and apply final processing.
  961. self.post_init()
  962. def get_input_embeddings(self):
  963. return self.model.language_model.get_input_embeddings()
  964. def set_input_embeddings(self, value):
  965. self.model.language_model.set_input_embeddings(value)
  966. def prepare_embeddings_for_image_generation(self, inputs: torch.Tensor) -> torch.Tensor:
  967. hidden_state = self.model.generation_embeddings(inputs)
  968. hidden_state = self.model.generation_aligner(hidden_state)
  969. return hidden_state
  970. @can_return_tuple
  971. @auto_docstring
  972. def forward(
  973. self,
  974. input_ids: Optional[torch.LongTensor] = None,
  975. pixel_values: Optional[torch.FloatTensor] = None,
  976. attention_mask: Optional[torch.Tensor] = None,
  977. position_ids: Optional[torch.LongTensor] = None,
  978. past_key_values: Optional[Cache] = None,
  979. cache_position: Optional[torch.LongTensor] = None,
  980. inputs_embeds: Optional[torch.FloatTensor] = None,
  981. labels: Optional[torch.LongTensor] = None,
  982. use_cache: Optional[bool] = None,
  983. logits_to_keep: Union[int, torch.Tensor] = 0,
  984. **kwargs: Unpack[TransformersKwargs],
  985. ):
  986. r"""
  987. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  988. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  989. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  990. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  991. """
  992. outputs = self.model(
  993. input_ids=input_ids,
  994. pixel_values=pixel_values,
  995. attention_mask=attention_mask,
  996. position_ids=position_ids,
  997. past_key_values=past_key_values,
  998. inputs_embeds=inputs_embeds,
  999. use_cache=use_cache,
  1000. cache_position=cache_position,
  1001. **kwargs,
  1002. )
  1003. hidden_states = outputs.last_hidden_state
  1004. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  1005. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1006. logits = self.lm_head(hidden_states[:, slice_indices, :])
  1007. loss = None
  1008. if labels is not None:
  1009. loss = self.loss_function(
  1010. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  1011. )
  1012. return JanusCausalLMOutputWithPast(
  1013. loss=loss,
  1014. logits=logits,
  1015. past_key_values=outputs.past_key_values,
  1016. hidden_states=outputs.hidden_states,
  1017. attentions=outputs.attentions,
  1018. image_hidden_states=outputs.image_hidden_states,
  1019. )
  1020. def prepare_inputs_for_generation(
  1021. self,
  1022. input_ids,
  1023. pixel_values=None,
  1024. past_key_values=None,
  1025. attention_mask=None,
  1026. inputs_embeds=None,
  1027. cache_position=None,
  1028. logits_to_keep=None,
  1029. **kwargs,
  1030. ):
  1031. # Overwritten -- extra custom processing
  1032. model_inputs = super().prepare_inputs_for_generation(
  1033. input_ids,
  1034. past_key_values=past_key_values,
  1035. inputs_embeds=inputs_embeds,
  1036. attention_mask=attention_mask,
  1037. cache_position=cache_position,
  1038. logits_to_keep=logits_to_keep,
  1039. **kwargs,
  1040. )
  1041. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  1042. # Otherwise we need pixel values to be passed to model
  1043. if cache_position[0] == 0:
  1044. model_inputs["pixel_values"] = pixel_values
  1045. return model_inputs
  1046. def decode_image_tokens(self, image_tokens: torch.Tensor):
  1047. """
  1048. Decodes generated image tokens from language model to continuous pixel values
  1049. with VQGAN module via upsampling.
  1050. Args:
  1051. image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
  1052. The tensors corresponding to the input images.
  1053. """
  1054. decoded_image = self.model.vqmodel.decode(image_tokens)
  1055. decoded_image = decoded_image.permute(0, 2, 3, 1)
  1056. return decoded_image
  1057. @torch.no_grad
  1058. def generate(
  1059. self,
  1060. inputs: Optional[torch.Tensor] = None,
  1061. attention_mask: Optional[torch.LongTensor] = None,
  1062. logits_processor: Optional[LogitsProcessorList] = None,
  1063. **kwargs,
  1064. ):
  1065. # 1. Handle generation config and model kwargs
  1066. generation_config = kwargs.pop("generation_config", self.generation_config)
  1067. generation_config = copy.deepcopy(generation_config)
  1068. # Default to "text" generation if mode isn't provided
  1069. generation_mode = kwargs.pop("generation_mode", "text")
  1070. if generation_mode == "text":
  1071. # Set guidance_scale=None to prevent running UnbatchedCFG processor.
  1072. return super().generate(
  1073. inputs=inputs,
  1074. attention_mask=attention_mask,
  1075. generation_config=generation_config,
  1076. guidance_scale=None,
  1077. **kwargs,
  1078. )
  1079. model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
  1080. # Validate generation mode
  1081. if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
  1082. raise ValueError(
  1083. "Got incompatible mode for Image Generation, should be one of greedy or sampling. "
  1084. "Ensure that beam search is de-activated by setting `num_beams=1`."
  1085. )
  1086. # Validate the configuration and model kwargs
  1087. generation_config.validate()
  1088. self._validate_model_kwargs(model_kwargs.copy())
  1089. # 2. Initialize logit processors
  1090. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
  1091. # Set `use_cache=True` as we will be using input embeds for generation.
  1092. model_kwargs["use_cache"] = True
  1093. if generation_config.guidance_scale is None:
  1094. logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.")
  1095. generation_config.guidance_scale = 5
  1096. model_kwargs["guidance_scale"] = generation_config.guidance_scale
  1097. # 3. Prepare model inputs
  1098. input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
  1099. inputs, generation_config.bos_token_id, model_kwargs
  1100. )
  1101. dtype, device = input_ids.dtype, input_ids.device
  1102. if len(input_ids.shape) != 2:
  1103. raise ValueError(
  1104. f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}"
  1105. "Passing `inputs embeds` is not supported currently."
  1106. )
  1107. # Prepare special tokens which will be used generate internally.
  1108. kwargs_has_attention_mask = attention_mask is not None
  1109. self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
  1110. # 4. Add CFG processor along with user passed logit processor.
  1111. if generation_config.guidance_scale and generation_config.guidance_scale > 1:
  1112. logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
  1113. generation_config.guidance_scale = None # Reset to prevent processor duplication.
  1114. # 5. Prepare logits processor
  1115. logits_processor = self._get_logits_processor(
  1116. generation_config=generation_config,
  1117. input_ids_seq_length=input_ids.shape[1],
  1118. encoder_input_ids=input_ids,
  1119. prefix_allowed_tokens_fn=None,
  1120. logits_processor=logits_processor,
  1121. device=device,
  1122. )
  1123. # 6. Expand inputs for multiple image generations per prompt.
  1124. input_ids, model_kwargs = self._expand_inputs_for_generation(
  1125. input_ids=input_ids,
  1126. attention_mask=attention_mask,
  1127. expand_size=generation_config.num_return_sequences,
  1128. **model_kwargs,
  1129. )
  1130. # 7. Prepare input and model caches
  1131. num_image_tokens = self.model.vision_model.config.num_image_tokens
  1132. batch_size, seq_len = input_ids.shape
  1133. input_tokens = input_ids.repeat(2, 1) # Double batch size for conditional/unconditional logits
  1134. attention_mask = model_kwargs.pop("attention_mask", None)
  1135. attention_mask = attention_mask.repeat(2, 1)
  1136. model_kwargs["attention_mask"] = attention_mask
  1137. # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
  1138. mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
  1139. input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
  1140. )
  1141. input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)
  1142. inputs_embeds = self.get_input_embeddings()(input_tokens)
  1143. model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
  1144. if model_kwargs.get("past_key_values", None) is None:
  1145. # Prepare cache if not provided.
  1146. model_kwargs["past_key_values"] = self._get_cache(
  1147. cache_implementation=generation_config.cache_implementation or "static",
  1148. # batch_size should account for both conditional/unconditional input; hence multiplied by 2.
  1149. batch_size=batch_size * 2,
  1150. # we should have at least a cache len of seq_len + num_image_tokens.
  1151. max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
  1152. model_kwargs=model_kwargs,
  1153. )
  1154. # Placeholder for generated tokens.
  1155. generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device)
  1156. # 8. init attention / hidden states / scores tuples
  1157. output_attentions = generation_config.output_attentions
  1158. output_hidden_states = generation_config.output_hidden_states
  1159. output_scores = generation_config.output_scores
  1160. output_logits = generation_config.output_logits
  1161. return_dict_in_generate = generation_config.return_dict_in_generate
  1162. raw_scores = () if (return_dict_in_generate and output_scores) else None
  1163. raw_logits = () if (return_dict_in_generate and output_logits) else None
  1164. decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
  1165. decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
  1166. for i in range(num_image_tokens):
  1167. model_inputs = self.prepare_inputs_for_generation(
  1168. inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
  1169. )
  1170. model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
  1171. model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)
  1172. outputs = self.model.language_model(
  1173. **model_inputs,
  1174. output_attentions=output_attentions,
  1175. output_hidden_states=output_hidden_states,
  1176. )
  1177. # Update model_kwargs like cache_position for next generation.
  1178. model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
  1179. hidden_state = outputs.last_hidden_state[:, -1, :].clone()
  1180. # Generate scores using the generation head (Not using above defined LM Head)
  1181. scores = self.model.generation_head(hidden_state)
  1182. next_token_scores = logits_processor(input_ids, scores)
  1183. # Sample next token.
  1184. if generation_config.do_sample:
  1185. probs = torch.softmax(next_token_scores, dim=-1)
  1186. next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
  1187. else:
  1188. next_token = torch.argmax(next_token_scores, dim=-1)
  1189. generated_tokens[:, i] = next_token
  1190. # Prepare embeddings for the next step.
  1191. next_token = torch.cat([next_token, next_token])
  1192. next_token = next_token.unsqueeze(-1)
  1193. inputs_embeds = self.prepare_embeddings_for_image_generation(next_token)
  1194. if return_dict_in_generate:
  1195. if output_scores:
  1196. raw_scores += (scores,)
  1197. if output_logits:
  1198. raw_logits += (hidden_state.float(),)
  1199. if output_attentions:
  1200. decoder_attentions += outputs.attentions
  1201. if output_hidden_states:
  1202. decoder_hidden_states += outputs.hidden_states
  1203. if return_dict_in_generate:
  1204. return GenerateDecoderOnlyOutput(
  1205. sequences=generated_tokens,
  1206. scores=scores,
  1207. logits=raw_logits,
  1208. attentions=decoder_attentions,
  1209. hidden_states=decoder_hidden_states,
  1210. past_key_values=outputs.past_key_values,
  1211. )
  1212. else:
  1213. return generated_tokens
  1214. __all__ = ["JanusPreTrainedModel", "JanusForConditionalGeneration", "JanusModel", "JanusVQVAE", "JanusVisionModel"]