modular_janus.py 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758
  1. # coding=utf-8
  2. # Copyright 2025 Deepseek AI and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import copy
  16. from collections.abc import Iterable
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import numpy as np
  20. import torch
  21. import torch.nn.functional as F
  22. import torch.utils.checkpoint
  23. from torch import nn
  24. from transformers.models.blip.image_processing_blip import BlipImageProcessor
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache
  27. from ...configuration_utils import PretrainedConfig
  28. from ...generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMixin, GenerationMode, LogitsProcessorList
  29. from ...generation.utils import GenerateDecoderOnlyOutput
  30. from ...image_processing_utils import BatchFeature, get_size_dict
  31. from ...image_transforms import convert_to_rgb, resize, to_channel_dimension_format
  32. from ...image_utils import (
  33. ChannelDimension,
  34. ImageInput,
  35. PILImageResampling,
  36. get_image_size,
  37. infer_channel_dimension_format,
  38. is_scaled_image,
  39. make_flat_list_of_images,
  40. to_numpy_array,
  41. valid_images,
  42. validate_preprocess_arguments,
  43. )
  44. from ...modeling_outputs import ModelOutput
  45. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  46. from ...processing_utils import Unpack
  47. from ...utils import (
  48. TensorType,
  49. TransformersKwargs,
  50. auto_docstring,
  51. can_return_tuple,
  52. filter_out_non_signature_kwargs,
  53. is_vision_available,
  54. logging,
  55. )
  56. from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel
  57. from ..blip_2.modeling_blip_2 import Blip2VisionModel
  58. from ..chameleon.configuration_chameleon import ChameleonVQVAEConfig
  59. from ..chameleon.modeling_chameleon import (
  60. ChameleonVQVAE,
  61. ChameleonVQVAEEncoderAttnBlock,
  62. ChameleonVQVAEEncoderConvDownsample,
  63. ChameleonVQVAEEncoderResnetBlock,
  64. ChameleonVQVAEVectorQuantizer,
  65. )
  66. from ..idefics.modeling_idefics import IdeficsBaseModelOutputWithPast, IdeficsCausalLMOutputWithPast
  67. from ..llama.modeling_llama import eager_attention_forward
  68. from ..siglip.configuration_siglip import SiglipVisionConfig
  69. from ..siglip.modeling_siglip import SiglipEncoder, SiglipEncoderLayer, SiglipVisionEmbeddings
if is_vision_available():
    # PIL is imported only when `is_vision_available()` reports True.
    import PIL

logger = logging.get_logger(__name__)

# Configuration classes
class JanusVisionConfig(SiglipVisionConfig):
    r"""
    This is the configuration class to store the configuration of a [`JanusVisionModel`]. It is used to instantiate a
    `JanusVisionModel` according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        image_size (`int`, *optional*, defaults to 384):
            The size (resolution) of each image.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, and `"gelu_new"` are supported.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            Ratio of MLP hidden dimensionality to embedding dimensionality.
        attention_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries, keys, and values in the attention layers.
        hidden_dropout_rate (`float`, *optional*, defaults to 0.0):
            The dropout probability for fully connected layers in the encoder.
        projection_dim (`int`, *optional*, defaults to 2048):
            Dimensionality of the MLP projection head.
        projection_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for the projection layer.
        use_qk_norm (`bool`, *optional*, defaults to `False`):
            Whether to normalize the query and key matrices.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated normal initializer for initializing all weight matrices.
        depth (`int`, *optional*, defaults to 2):
            Number of hidden layers in the aligner module.
        num_image_tokens (`int`, *optional*, defaults to 576):
            Number of image tokens. The default equals `(image_size // patch_size) ** 2` for the default
            `image_size=384` and `patch_size=16`.
    """

    model_type = "janus_vision_model"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_channels=3,
        patch_size=16,
        image_size=384,
        attention_dropout=0.0,
        layer_norm_eps=1e-6,
        hidden_act="gelu",
        mlp_ratio=4.0,
        attention_bias=True,
        hidden_dropout_rate=0.0,
        projection_dim=2048,
        projection_dropout=0.0,
        use_qk_norm=False,
        initializer_range=0.02,
        depth=2,
        num_image_tokens=576,
        **kwargs,
    ):
        # Only the shared SigLIP fields are forwarded to the parent; the Janus-specific
        # fields are stored below.
        super().__init__(
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_channels=num_channels,
            patch_size=patch_size,
            image_size=image_size,
            attention_dropout=attention_dropout,
            layer_norm_eps=layer_norm_eps,
            hidden_act=hidden_act,
            **kwargs,
        )
        # Janus derives the MLP width from `mlp_ratio` (JanusVisionMLP uses
        # `int(hidden_size * mlp_ratio)`), so the inherited `intermediate_size` is dropped.
        del self.intermediate_size

        self.mlp_ratio = mlp_ratio
        self.attention_bias = attention_bias
        self.hidden_dropout_rate = hidden_dropout_rate
        self.projection_dim = projection_dim
        self.projection_dropout = projection_dropout
        self.use_qk_norm = use_qk_norm
        self.initializer_range = initializer_range
        # Number of hidden layers used by the vision aligner MLP.
        self.depth = depth
        self.num_image_tokens = num_image_tokens
  165. class JanusVQVAEConfig(ChameleonVQVAEConfig):
  166. r"""
  167. This is the configuration class to store the configuration of a [`JanusVQVAEModel`]. It is used to instantiate a
  168. `JanusVQVAEModel` according to the specified arguments, defining the model architecture.
  169. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  170. documentation from [`PretrainedConfig`] for more information. Instantiating a
  171. configuration with the defaults will yield a similar configuration to the VQModel of the
  172. [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B).
  173. Args:
  174. embed_dim (`int`, *optional*, defaults to 8):
  175. Dimensionality of each embedding vector.
  176. num_embeddings (`int`, *optional*, defaults to 16384):
  177. Number of codebook embeddings.
  178. double_latent (`bool`, *optional*, defaults to `False`):
  179. Whether to use double z channels.
  180. latent_channels (`int`, *optional*, defaults to 256):
  181. Number of channels for the latent space.
  182. num_patches (`int`, *optional*, defaults to 32):
  183. Num of patches the input images can be divided into.
  184. in_channels (`int`, *optional*, defaults to 3):
  185. Number of input channels.
  186. out_channels (`int`, *optional*, defaults to 3):
  187. Number of out channels.
  188. base_channels (`int`, *optional*, defaults to 128):
  189. Base channel count.
  190. channel_multiplier (`list[int]`, *optional*, defaults to `[1, 1, 2, 2, 4]`):
  191. Channel multipliers for each resolution.
  192. num_res_blocks (`int`, *optional*, defaults to 2):
  193. Number of residual blocks.
  194. dropout (`float`, *optional*, defaults to 0.0):
  195. Dropout rate.
  196. initializer_range (`float`, *optional*, defaults to 0.02):
  197. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  198. projection_dim (`int`, *optional*, defaults to 2048):
  199. Dimensionality of the MLP projection head.
  200. num_hidden_layers (`int`, *optional*, defaults to 2):
  201. Number of hidden layers in VAVAE MLP Connecter module.
  202. hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
  203. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  204. `"relu"`, `"silu"` and `"gelu_new"` are supported.
  205. image_token_embed_dim (`int`, *optional*, defaults to 2048):
  206. Dimension of image embeddings. It should be same as the dimensionality of text embeddings.
  207. """
  208. def __init__(
  209. self,
  210. embed_dim: int = 8,
  211. num_embeddings: int = 16384,
  212. double_latent: bool = False,
  213. latent_channels: int = 256,
  214. num_patches: int = 32,
  215. in_channels: int = 3,
  216. out_channels: int = 3,
  217. base_channels: int = 128,
  218. channel_multiplier: list[int] = [1, 1, 2, 2, 4],
  219. num_res_blocks: int = 2,
  220. dropout: float = 0.0,
  221. initializer_range=0.02,
  222. projection_dim=2048,
  223. num_hidden_layers=2,
  224. hidden_act="gelu",
  225. image_token_embed_dim=2048,
  226. **kwargs,
  227. ):
  228. super().__init__(
  229. embed_dim=embed_dim,
  230. num_embeddings=num_embeddings,
  231. double_latent=double_latent,
  232. latent_channels=latent_channels,
  233. in_channels=in_channels,
  234. base_channels=base_channels,
  235. channel_multiplier=channel_multiplier,
  236. num_res_blocks=num_res_blocks,
  237. dropout=dropout,
  238. initializer_range=initializer_range,
  239. **kwargs,
  240. )
  241. self.num_patches = num_patches
  242. self.out_channels = out_channels
  243. self.projection_dim = projection_dim
  244. self.num_hidden_layers = num_hidden_layers
  245. self.hidden_act = hidden_act
  246. self.image_token_embed_dim = image_token_embed_dim
  247. del self.resolution
  248. del self.attn_resolutions
  249. del self.attn_type
  250. class JanusConfig(PretrainedConfig):
  251. r"""
  252. This is the configuration class to store the configuration of a [`JanusModel`]. It is used to instantiate an
  253. Janus model according to the specified arguments, defining the model architecture. Instantiating a configuration
  254. with the defaults will yield a similar configuration to that of the Janus-1B or Janus-7B models.
  255. e.g. [deepseek-community/Janus-Pro-1B](https://huggingface.co/deepseek-community/Janus-Pro-1B) or
  256. [deepseek-community/Janus-Pro-7B](https://huggingface.co/deepseek-community/Janus-Pro-7B)
  257. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  258. documentation from [`PretrainedConfig`] for more information.
  259. Args:
  260. text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `LlamaConfig`):
  261. The config object or dictionary of the text backbone.
  262. vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVisionConfig`):
  263. The config object or dictionary of the vision backbone.
  264. vq_config (`Union[AutoConfig, dict]`, *optional*, defaults to `JanusVQVAEConfig`):
  265. The config object or dictionary of the VQVAE backbone.
  266. image_token_id (`int`, *optional*, defaults to 100581):
  267. Token index of a placeholder image token.
  268. Example:
  269. ```python
  270. >>> from transformers import JanusForConditionalGeneration, JanusConfig, JanusVisionConfig, JanusVQVAEConfig, LlamaConfig
  271. >>> # Initializing a Janus vision config
  272. >>> vision_config = JanusVisionConfig()
  273. >>> # Initializing a Llama config
  274. >>> text_config = LlamaConfig()
  275. >>> # Initializing a VQ config
  276. >>> vq_config = JanusVQVAEConfig()
  277. >>> # Initializing a Janus Pro 1B style configuration
  278. >>> configuration = JanusConfig(vision_config=vision_config, text_config=text_config, vq_config=vq_config)
  279. >>> # Initializing a model from the Janus Pro 1B style configuration
  280. >>> model = JanusForConditionalGeneration(configuration)
  281. >>> # Accessing the model configuration
  282. >>> configuration = model.config
  283. ```"""
  284. model_type = "janus"
  285. sub_configs = {
  286. "text_config": AutoConfig,
  287. "vision_config": JanusVisionConfig,
  288. "vq_config": JanusVQVAEConfig,
  289. }
  290. def __init__(
  291. self,
  292. text_config=None,
  293. vision_config=None,
  294. vq_config=None,
  295. image_token_id=100581,
  296. **kwargs,
  297. ):
  298. if isinstance(text_config, dict):
  299. text_config["model_type"] = text_config.get("model_type", "llama")
  300. self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
  301. elif text_config is None:
  302. logger.info("`text_config` is None. Initializing with default values")
  303. self.text_config = CONFIG_MAPPING["llama"]()
  304. elif isinstance(text_config, PretrainedConfig):
  305. self.text_config = text_config
  306. else:
  307. raise ValueError(
  308. f"Invalid type for `text_config`. Must be either `dict` or `LlamaConfig`."
  309. f" Type found: {type(text_config)}"
  310. )
  311. if vision_config is None:
  312. logger.info("`vision_config` is None. Initializing with default JanusVisionConfig values")
  313. self.vision_config = JanusVisionConfig()
  314. elif isinstance(vision_config, dict):
  315. self.vision_config = JanusVisionConfig(**vision_config)
  316. elif isinstance(vision_config, JanusVisionConfig):
  317. self.vision_config = vision_config
  318. else:
  319. raise ValueError(
  320. f"Invalid type for `vision_config`. Must be either `dict` or `JanusVisionConfig`."
  321. f" Type found: {type(vision_config)}"
  322. )
  323. if vq_config is None:
  324. logger.info("`vq_config` is None. Initializing with default JanusVQVAEConfig values")
  325. self.vq_config = JanusVQVAEConfig()
  326. elif isinstance(vq_config, dict):
  327. self.vq_config = JanusVQVAEConfig(**vq_config)
  328. elif isinstance(vq_config, JanusVQVAEConfig):
  329. self.vq_config = vq_config
  330. else:
  331. raise ValueError(
  332. f"Invalid type for `vq_config`. Must be either `dict` or `JanusVQVAEConfig`."
  333. f" Type found: {type(vq_config)}"
  334. )
  335. self.initializer_range = self.vision_config.initializer_range
  336. # This dimension is required when decoding discrete image tokens to continuous input.
  337. self.vq_config.num_patches = self.vision_config.image_size // self.vision_config.patch_size
  338. # The default is only the index for the 1B model, 7B uses a different one
  339. self.image_token_id = image_token_id
  340. super().__init__(**kwargs)
  341. @auto_docstring
  342. class JanusPreTrainedModel(PreTrainedModel):
  343. config: JanusConfig
  344. base_model_prefix = "model"
  345. supports_gradient_checkpointing = True
  346. _no_split_modules = ["LlamaDecoderLayer", "JanusVisionEncoderLayer"]
  347. _skip_keys_device_placement = ["past_key_values", "causal_mask"]
  348. _supports_flash_attn = True
  349. _supports_sdpa = True
  350. _can_compile_fullgraph = True
  351. _supports_param_buffer_assignment = False
  352. @dataclass
  353. @auto_docstring(
  354. custom_intro="""
  355. Base class for Janus VQ-VAE mode model outputs.
  356. """
  357. )
  358. class JanusVQVAEOutput(ModelOutput):
  359. r"""
  360. decoded_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  361. Reconstructed pixel values after encoding and decoding the input.
  362. embedding_loss (`torch.FloatTensor`):
  363. Embedding loss.
  364. """
  365. decoded_pixel_values: Optional[torch.FloatTensor] = None
  366. embedding_loss: Optional[torch.FloatTensor] = None
# Identical to IdeficsBaseModelOutputWithPast (no added fields or overrides); exists to give
# Janus its own output class name.
class JanusBaseModelOutputWithPast(IdeficsBaseModelOutputWithPast):
    pass
# Identical to IdeficsCausalLMOutputWithPast (no added fields or overrides); exists to give
# Janus its own output class name.
class JanusCausalLMOutputWithPast(IdeficsCausalLMOutputWithPast):
    pass
  371. class JanusVisionEmbeddings(SiglipVisionEmbeddings):
  372. def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
  373. _, _, height, width = pixel_values.shape
  374. target_dtype = self.patch_embedding.weight.dtype
  375. patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
  376. embeddings = patch_embeds.flatten(2).transpose(1, 2)
  377. if interpolate_pos_encoding:
  378. pos_embeds = self.interpolate_pos_encoding(embeddings, height, width)
  379. else:
  380. pos_embeds = self.position_embedding(self.position_ids)
  381. embeddings = embeddings + pos_embeds
  382. return embeddings
  383. class JanusVisionAttention(nn.Module):
  384. """Attention Class for Janus Vision Encoder"""
  385. def __init__(self, config: JanusVisionConfig):
  386. super().__init__()
  387. self.config = config
  388. self.embed_dim = config.hidden_size
  389. self.num_heads = config.num_attention_heads
  390. self.head_dim = self.embed_dim // self.num_heads
  391. if self.head_dim * self.num_heads != self.embed_dim:
  392. raise ValueError(
  393. f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
  394. f" {self.num_heads})."
  395. )
  396. self.scale = self.head_dim**-0.5
  397. self.attention_dropout = config.attention_dropout
  398. proj_dropout = config.projection_dropout
  399. qk_norm = config.use_qk_norm
  400. self.is_causal = False
  401. # Janus has no MHA, hence for `eager_attention_forward` call setting `num_key_value_groups` to 1.
  402. self.num_key_value_groups = 1
  403. self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  404. self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  405. self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=config.attention_bias)
  406. self.projection_layer = nn.Linear(self.embed_dim, self.embed_dim)
  407. self.projection_dropout = nn.Dropout(proj_dropout) if proj_dropout > 0 else nn.Identity()
  408. self.q_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
  409. self.k_norm = nn.LayerNorm(self.embed_dim) if qk_norm else nn.Identity()
  410. def forward(
  411. self,
  412. hidden_states: torch.Tensor,
  413. attention_mask: Optional[torch.Tensor] = None,
  414. **kwargs: Unpack[TransformersKwargs],
  415. ):
  416. batch_size, seq_len, _ = hidden_states.size()
  417. query_states = self.q_proj(hidden_states)
  418. key_states = self.k_proj(hidden_states)
  419. value_states = self.v_proj(hidden_states)
  420. query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
  421. query_states = self.q_norm(query_states)
  422. key_states = key_states.reshape(-1, self.num_heads, self.head_dim)
  423. key_states = self.k_norm(key_states)
  424. query_states = query_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  425. key_states = key_states.reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  426. value_states = value_states.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
  427. attention_interface: Callable = eager_attention_forward
  428. if self.config._attn_implementation != "eager":
  429. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  430. attn_output, attn_weights = attention_interface(
  431. self,
  432. query_states,
  433. key_states,
  434. value_states,
  435. attention_mask,
  436. dropout=0.0 if not self.training else self.attention_dropout,
  437. scaling=self.scale,
  438. is_causal=self.is_causal,
  439. **kwargs,
  440. )
  441. attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
  442. output = self.projection_layer(attn_output)
  443. output = self.projection_dropout(output)
  444. return output, attn_weights
  445. class JanusVisionMLP(nn.Module):
  446. def __init__(self, config: JanusVisionConfig):
  447. super().__init__()
  448. self.config = config
  449. self.intermediate_size = int(config.hidden_size * config.mlp_ratio)
  450. self.activation_fn = ACT2FN[config.hidden_act] # Gelu act
  451. self.fc1 = nn.Linear(config.hidden_size, self.intermediate_size)
  452. self.fc2 = nn.Linear(self.intermediate_size, config.hidden_size)
  453. self.dropout1 = nn.Dropout(config.hidden_dropout_rate)
  454. self.dropout2 = nn.Dropout(config.hidden_dropout_rate)
  455. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  456. hidden_states = self.fc1(hidden_states)
  457. hidden_states = self.activation_fn(hidden_states)
  458. hidden_states = self.dropout1(hidden_states)
  459. hidden_states = self.fc2(hidden_states)
  460. hidden_states = self.dropout2(hidden_states)
  461. return hidden_states
class JanusVisionEncoderLayer(SiglipEncoderLayer):
    # SiglipEncoderLayer with Janus submodules: `forward` is not overridden here, only the
    # attention, layer norms and MLP are (re)assigned from `JanusVisionConfig`.
    def __init__(self, config: JanusVisionConfig):
        super().__init__(config)
        # NOTE(review): in this modular file the parent __init__ is presumably expanded by the
        # modular converter; the assignments below replace the corresponding SigLIP submodules.
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = JanusVisionAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = JanusVisionMLP(config)
class JanusVisionEncoder(SiglipEncoder):
    # SiglipEncoder whose layer stack is rebuilt from `config.num_hidden_layers`
    # JanusVisionEncoderLayer instances; `forward` is inherited unchanged.
    def __init__(self, config: JanusVisionConfig):
        super().__init__(config)
        self.layers = nn.ModuleList([JanusVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
class JanusVisionModel(Blip2VisionModel):
    # Blip2VisionModel with its encoder replaced by JanusVisionEncoder (a SiglipEncoder
    # subclass); everything not reassigned here is inherited unchanged.
    def __init__(self, config: JanusVisionConfig):
        super().__init__(config)
        # NOTE(review): at plain Python runtime the parent first builds its own encoder, which
        # this assignment then replaces; the modular converter presumably inlines the parent
        # __init__ instead — confirm against the generated modeling file.
        self.encoder = JanusVisionEncoder(config)
  479. class JanusVisionAlignerMLP(nn.Module):
  480. def __init__(self, config: JanusVisionConfig):
  481. super().__init__()
  482. self.fc1 = nn.Linear(config.hidden_size, config.projection_dim)
  483. self.hidden_layers = nn.ModuleList(
  484. [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.depth)]
  485. )
  486. self.activation_fn = ACT2FN[config.hidden_act]
  487. def forward(self, hidden_states):
  488. hidden_states = self.fc1(hidden_states)
  489. for layer in self.hidden_layers:
  490. hidden_states = self.activation_fn(hidden_states)
  491. hidden_states = layer(hidden_states)
  492. return hidden_states
  493. class JanusVQVAEVectorQuantizer(ChameleonVQVAEVectorQuantizer):
  494. def __init__(self, config: JanusVQVAEConfig):
  495. super().__init__(config)
  496. self.quant_state_dims = [config.num_patches] * 2
  497. def get_codebook_entry(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
  498. batch_size = image_tokens.shape[0]
  499. emb_dim: int = self.embedding.weight.shape[-1]
  500. # get quantized latent vectors
  501. hidden_state_quant = self.embedding(image_tokens)
  502. # l2 normalization on the last dimension
  503. hidden_state_quant = F.normalize(hidden_state_quant, p=2, dim=-1)
  504. # reshape back to match original input shape
  505. hidden_state_quant = hidden_state_quant.view((batch_size, *self.quant_state_dims, emb_dim))
  506. hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
  507. return hidden_state_quant
# Same behavior as ChameleonVQVAEEncoderResnetBlock (no overrides); named for Janus and used by
# JanusVQVAEMidBlock and JanusVQVAEEncoder.
class JanusVQVAEResnetBlock(ChameleonVQVAEEncoderResnetBlock):
    pass
# Same behavior as ChameleonVQVAEEncoderAttnBlock (no overrides); named for Janus and used by
# JanusVQVAEMidBlock and JanusVQVAEEncoder.
class JanusVQVAEAttnBlock(ChameleonVQVAEEncoderAttnBlock):
    pass
# Same behavior as ChameleonVQVAEEncoderConvDownsample (no overrides); named for Janus and used
# between resolution levels of JanusVQVAEEncoder.
class JanusVQVAEConvDownsample(ChameleonVQVAEEncoderConvDownsample):
    pass
  514. class JanusVQVAEConvUpsample(nn.Module):
  515. def __init__(self, in_channels):
  516. super().__init__()
  517. self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
  518. def forward(self, hidden_states):
  519. hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  520. hidden_states = self.conv(hidden_states)
  521. return hidden_states
  522. class JanusVQVAEMidBlock(nn.Module):
  523. def __init__(self, config: JanusVQVAEConfig, channels: int):
  524. super().__init__()
  525. self.block_1 = JanusVQVAEResnetBlock(
  526. config=config,
  527. in_channels=channels,
  528. out_channels=channels,
  529. )
  530. self.attn_1 = JanusVQVAEAttnBlock(channels)
  531. self.block_2 = JanusVQVAEResnetBlock(
  532. config=config,
  533. in_channels=channels,
  534. out_channels=channels,
  535. )
  536. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  537. hidden_states = self.block_1(hidden_states)
  538. hidden_states = self.attn_1(hidden_states)
  539. hidden_states = self.block_2(hidden_states)
  540. return hidden_states
  541. class JanusVQVAEEncoder(nn.Module):
  542. def __init__(self, config):
  543. super().__init__()
  544. self.num_resolutions = len(config.channel_multiplier)
  545. self.num_res_blocks = config.num_res_blocks
  546. base_channels = config.base_channels
  547. in_channels = config.in_channels
  548. double_latent = config.double_latent
  549. latent_channels = config.latent_channels
  550. channel_multiplier = config.channel_multiplier
  551. self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
  552. in_channel_multiplier = (1,) + tuple(channel_multiplier)
  553. self.in_channel_multiplier = in_channel_multiplier
  554. self.down = nn.ModuleList()
  555. for i_level in range(self.num_resolutions):
  556. block = nn.ModuleList()
  557. attn = nn.ModuleList()
  558. block_in = base_channels * in_channel_multiplier[i_level]
  559. block_out = base_channels * channel_multiplier[i_level]
  560. for i_block in range(self.num_res_blocks):
  561. block.append(
  562. JanusVQVAEResnetBlock(
  563. config=config,
  564. in_channels=block_in,
  565. out_channels=block_out,
  566. )
  567. )
  568. block_in = block_out
  569. if i_level == self.num_resolutions - 1:
  570. attn.append(JanusVQVAEAttnBlock(block_in))
  571. down = nn.Module()
  572. down.block = block
  573. down.attn = attn
  574. if i_level != self.num_resolutions - 1:
  575. down.downsample = JanusVQVAEConvDownsample(block_in)
  576. self.down.append(down)
  577. self.mid = JanusVQVAEMidBlock(config, block_in)
  578. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  579. self.conv_out = torch.nn.Conv2d(
  580. block_in,
  581. 2 * latent_channels if double_latent else latent_channels,
  582. kernel_size=3,
  583. stride=1,
  584. padding=1,
  585. )
  586. def forward(self, pixel_values: torch.LongTensor):
  587. # downsampling
  588. hidden_states = [self.conv_in(pixel_values)]
  589. for i_level in range(self.num_resolutions):
  590. for i_block in range(self.num_res_blocks):
  591. hidden_state = self.down[i_level].block[i_block](
  592. hidden_states[-1],
  593. )
  594. if len(self.down[i_level].attn) > 0:
  595. hidden_state = self.down[i_level].attn[i_block](hidden_state)
  596. hidden_states.append(hidden_state)
  597. if i_level != self.num_resolutions - 1:
  598. hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
  599. # middle
  600. last_hidden_state = hidden_states[-1]
  601. last_hidden_state = self.mid(last_hidden_state)
  602. # end
  603. last_hidden_state = self.norm_out(last_hidden_state)
  604. last_hidden_state *= torch.sigmoid(last_hidden_state)
  605. last_hidden_state = self.conv_out(last_hidden_state)
  606. return last_hidden_state
  607. class JanusVQVAEDecoder(nn.Module):
  608. def __init__(self, config):
  609. super().__init__()
  610. self.num_resolutions = len(config.channel_multiplier)
  611. self.num_res_blocks = config.num_res_blocks
  612. base_channels = config.base_channels
  613. latent_channels = config.latent_channels
  614. out_channels = config.out_channels
  615. # compute in_ch_mult, block_in and curr_res at lowest res
  616. block_in = base_channels * config.channel_multiplier[self.num_resolutions - 1]
  617. # z to block_in
  618. self.conv_in = torch.nn.Conv2d(latent_channels, block_in, kernel_size=3, stride=1, padding=1)
  619. # middle
  620. self.mid = JanusVQVAEMidBlock(config, block_in)
  621. # upsampling
  622. self.up = nn.ModuleList()
  623. for i_level in reversed(range(self.num_resolutions)):
  624. block = nn.ModuleList()
  625. attn = nn.ModuleList()
  626. block_out = base_channels * config.channel_multiplier[i_level]
  627. for i_block in range(self.num_res_blocks + 1):
  628. block.append(
  629. JanusVQVAEResnetBlock(
  630. config=config,
  631. in_channels=block_in,
  632. out_channels=block_out,
  633. )
  634. )
  635. block_in = block_out
  636. if i_level == self.num_resolutions - 1:
  637. attn.append(JanusVQVAEAttnBlock(block_in))
  638. up = nn.Module()
  639. up.block = block
  640. up.attn = attn
  641. if i_level != 0:
  642. up.upsample = JanusVQVAEConvUpsample(block_in)
  643. self.up.append(up)
  644. # end
  645. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  646. self.conv_out = torch.nn.Conv2d(block_in, out_channels, kernel_size=3, stride=1, padding=1)
  647. def forward(self, hidden_state: torch.FloatTensor) -> torch.FloatTensor:
  648. hidden_state = self.conv_in(hidden_state)
  649. # middle
  650. hidden_state = self.mid(hidden_state)
  651. # upsampling
  652. for i_level in range(self.num_resolutions):
  653. for i_block in range(self.num_res_blocks + 1):
  654. hidden_state = self.up[i_level].block[i_block](hidden_state)
  655. if len(self.up[i_level].attn) > 0:
  656. hidden_state = self.up[i_level].attn[i_block](hidden_state)
  657. if i_level != self.num_resolutions - 1:
  658. hidden_state = self.up[i_level].upsample(hidden_state)
  659. hidden_state = self.norm_out(hidden_state)
  660. hidden_state *= torch.sigmoid(hidden_state)
  661. hidden_state = self.conv_out(hidden_state)
  662. return hidden_state
  663. class JanusVQVAE(ChameleonVQVAE):
  664. _no_split_modules = [
  665. "JanusVQVAEAttnBlock",
  666. "JanusVQVAEResnetBlock",
  667. "JanusVQVAEVectorQuantizer",
  668. ]
  669. main_input_name = "pixel_values"
  670. def __init__(self, config: JanusVQVAEConfig):
  671. super().__init__(config)
  672. self.decoder = JanusVQVAEDecoder(config)
  673. self.gradient_checkpointing = False
  674. # Initialize the VQVAE model.
  675. self.post_init()
  676. def decode(self, image_tokens: torch.LongTensor) -> torch.FloatTensor:
  677. """
  678. Decodes quantized token IDs into pixel values.
  679. Args:
  680. image_tokens (torch.LongTensor): Batch of token IDs.
  681. Returns:
  682. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  683. Pixel values decoded from the token IDs.
  684. """
  685. if image_tokens.shape[1] != self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]:
  686. raise ValueError(
  687. f"Expected `image_tokens` to have shape `(batch_size, {self.quantize.quant_state_dims[0] * self.quantize.quant_state_dims[1]})`, "
  688. f"but got shape `{image_tokens.shape}`."
  689. )
  690. codebook_entry = self.quantize.get_codebook_entry(image_tokens)
  691. hidden_states = self.post_quant_conv(codebook_entry)
  692. pixel_values = self.decoder(hidden_states)
  693. return pixel_values
  694. @can_return_tuple
  695. @auto_docstring
  696. def forward(
  697. self,
  698. pixel_values: torch.FloatTensor,
  699. ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
  700. batch_size = pixel_values.shape[0]
  701. quant, embedding_loss, indices = self.encode(pixel_values)
  702. decoded_pixel_values = self.decode(indices.view(batch_size, -1))
  703. return JanusVQVAEOutput(decoded_pixel_values, embedding_loss)
  704. class JanusVQVAEAlignerMLP(nn.Module):
  705. def __init__(self, config: JanusVQVAEConfig):
  706. super().__init__()
  707. self.fc1 = nn.Linear(config.embed_dim, config.projection_dim)
  708. self.hidden_layers = nn.ModuleList(
  709. [nn.Linear(config.projection_dim, config.projection_dim) for _ in range(1, config.num_hidden_layers)]
  710. )
  711. self.activation_fn = ACT2FN[config.hidden_act]
  712. def forward(self, hidden_states):
  713. hidden_states = self.fc1(hidden_states)
  714. for layer in self.hidden_layers:
  715. hidden_states = self.activation_fn(hidden_states)
  716. hidden_states = layer(hidden_states)
  717. return hidden_states
  718. class JanusVQVAEHead(nn.Module):
  719. """Head used for sampling tokens in image generation, replacing the usual lm head."""
  720. def __init__(self, config: JanusVQVAEConfig):
  721. super().__init__()
  722. self.proj_out = nn.Linear(config.image_token_embed_dim, config.projection_dim)
  723. self.activation_fn = ACT2FN[config.hidden_act]
  724. self.vision_head = nn.Linear(config.projection_dim, config.num_embeddings)
  725. def forward(self, hidden_states: torch.Tensor) -> torch.tensor:
  726. hidden_states = self.proj_out(hidden_states)
  727. hidden_states = self.activation_fn(hidden_states)
  728. hidden_states = self.vision_head(hidden_states)
  729. return hidden_states
  730. @auto_docstring(
  731. custom_intro="""
  732. The Janus model which consists of a siglip vision backbone, a Llama language model and a VQ model.
  733. """
  734. )
  735. class JanusModel(JanusPreTrainedModel):
  736. def __init__(self, config: JanusConfig):
  737. super().__init__(config)
  738. self.config = config
  739. # This is necessary for backward compatibility, see SiglipModel initialization
  740. self.vision_model = JanusVisionModel._from_config(config.vision_config)
  741. self.aligner = JanusVisionAlignerMLP(self.vision_model.config)
  742. self.vqmodel = JanusVQVAE._from_config(config.vq_config)
  743. # Below generation_* modules are used for Image generation.
  744. # Embeddings used for image generation, instead of Janus vision embeddings.
  745. self.generation_embeddings = nn.Embedding(self.vqmodel.config.num_embeddings, self.vqmodel.config.embed_dim)
  746. self.generation_aligner = JanusVQVAEAlignerMLP(self.vqmodel.config)
  747. self.generation_head = JanusVQVAEHead(self.vqmodel.config)
  748. self.language_model = AutoModel.from_config(config=config.text_config)
  749. self.gradient_checkpointing = False
  750. # Initialize weights and apply final processing.
  751. self.post_init()
  752. def get_input_embeddings(self):
  753. return self.language_model.get_input_embeddings()
  754. def set_input_embeddings(self, value):
  755. self.language_model.set_input_embeddings(value)
  756. def get_image_features(self, pixel_values):
  757. image_embeds = self.vision_model(pixel_values)
  758. image_embeds = self.aligner(image_embeds.last_hidden_state)
  759. return image_embeds
  760. def get_placeholder_mask(
  761. self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
  762. ):
  763. """
  764. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  765. equal to the length of multimodal features. If the lengths are different, an error is raised.
  766. """
  767. if input_ids is None:
  768. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  769. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  770. )
  771. special_image_mask = special_image_mask.all(-1)
  772. else:
  773. special_image_mask = input_ids == self.config.image_token_id
  774. n_image_tokens = special_image_mask.sum()
  775. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  776. if inputs_embeds[special_image_mask].numel() != image_features.numel():
  777. n_image_features = image_features.shape[0] * image_features.shape[1]
  778. raise ValueError(
  779. f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
  780. )
  781. return special_image_mask
  782. @can_return_tuple
  783. @auto_docstring
  784. def forward(
  785. self,
  786. input_ids: Optional[torch.LongTensor] = None,
  787. pixel_values: Optional[torch.FloatTensor] = None,
  788. attention_mask: Optional[torch.Tensor] = None,
  789. position_ids: Optional[torch.LongTensor] = None,
  790. past_key_values: Optional[Cache] = None,
  791. cache_position: Optional[torch.LongTensor] = None,
  792. inputs_embeds: Optional[torch.FloatTensor] = None,
  793. use_cache: Optional[bool] = None,
  794. logits_to_keep: Union[int, torch.Tensor] = 0,
  795. **kwargs,
  796. ):
  797. if (input_ids is None) ^ (inputs_embeds is not None):
  798. raise ValueError(
  799. "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
  800. )
  801. if inputs_embeds is None:
  802. inputs_embeds = self.get_input_embeddings()(input_ids)
  803. if pixel_values is not None:
  804. image_embeds = self.get_image_features(pixel_values)
  805. image_features = image_embeds.reshape(-1, inputs_embeds.shape[-1])
  806. image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
  807. image_attention_mask = self.get_placeholder_mask(
  808. input_ids, inputs_embeds=inputs_embeds, image_features=image_features
  809. )
  810. inputs_embeds = inputs_embeds.masked_scatter(image_attention_mask, image_features)
  811. lm_output = self.language_model(
  812. inputs_embeds=inputs_embeds,
  813. attention_mask=attention_mask,
  814. position_ids=position_ids,
  815. past_key_values=past_key_values,
  816. use_cache=use_cache,
  817. cache_position=cache_position,
  818. logits_to_keep=logits_to_keep,
  819. **kwargs,
  820. )
  821. return JanusBaseModelOutputWithPast(
  822. last_hidden_state=lm_output.last_hidden_state,
  823. past_key_values=lm_output.past_key_values,
  824. hidden_states=lm_output.hidden_states,
  825. attentions=lm_output.attentions,
  826. image_hidden_states=image_embeds if pixel_values is not None else None,
  827. )
  828. class JanusForConditionalGeneration(JanusPreTrainedModel, GenerationMixin):
  829. _tied_weights_keys = ["model.language_model.embed_tokens.weight", "lm_head.weight"]
  830. _can_compile_fullgraph = True
  831. def __init__(self, config: JanusConfig):
  832. super().__init__(config)
  833. self.config = config
  834. self.model = JanusModel(config)
  835. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  836. # Initialize weights and apply final processing.
  837. self.post_init()
  838. def get_input_embeddings(self):
  839. return self.model.language_model.get_input_embeddings()
  840. def set_input_embeddings(self, value):
  841. self.model.language_model.set_input_embeddings(value)
  842. def prepare_embeddings_for_image_generation(self, inputs: torch.Tensor) -> torch.Tensor:
  843. hidden_state = self.model.generation_embeddings(inputs)
  844. hidden_state = self.model.generation_aligner(hidden_state)
  845. return hidden_state
  846. @can_return_tuple
  847. @auto_docstring
  848. def forward(
  849. self,
  850. input_ids: Optional[torch.LongTensor] = None,
  851. pixel_values: Optional[torch.FloatTensor] = None,
  852. attention_mask: Optional[torch.Tensor] = None,
  853. position_ids: Optional[torch.LongTensor] = None,
  854. past_key_values: Optional[Cache] = None,
  855. cache_position: Optional[torch.LongTensor] = None,
  856. inputs_embeds: Optional[torch.FloatTensor] = None,
  857. labels: Optional[torch.LongTensor] = None,
  858. use_cache: Optional[bool] = None,
  859. logits_to_keep: Union[int, torch.Tensor] = 0,
  860. **kwargs: Unpack[TransformersKwargs],
  861. ):
  862. r"""
  863. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  864. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  865. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  866. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  867. """
  868. outputs = self.model(
  869. input_ids=input_ids,
  870. pixel_values=pixel_values,
  871. attention_mask=attention_mask,
  872. position_ids=position_ids,
  873. past_key_values=past_key_values,
  874. inputs_embeds=inputs_embeds,
  875. use_cache=use_cache,
  876. cache_position=cache_position,
  877. **kwargs,
  878. )
  879. hidden_states = outputs.last_hidden_state
  880. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  881. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  882. logits = self.lm_head(hidden_states[:, slice_indices, :])
  883. loss = None
  884. if labels is not None:
  885. loss = self.loss_function(
  886. logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
  887. )
  888. return JanusCausalLMOutputWithPast(
  889. loss=loss,
  890. logits=logits,
  891. past_key_values=outputs.past_key_values,
  892. hidden_states=outputs.hidden_states,
  893. attentions=outputs.attentions,
  894. image_hidden_states=outputs.image_hidden_states,
  895. )
  896. def prepare_inputs_for_generation(
  897. self,
  898. input_ids,
  899. pixel_values=None,
  900. past_key_values=None,
  901. attention_mask=None,
  902. inputs_embeds=None,
  903. cache_position=None,
  904. logits_to_keep=None,
  905. **kwargs,
  906. ):
  907. # Overwritten -- extra custom processing
  908. model_inputs = super().prepare_inputs_for_generation(
  909. input_ids,
  910. past_key_values=past_key_values,
  911. inputs_embeds=inputs_embeds,
  912. attention_mask=attention_mask,
  913. cache_position=cache_position,
  914. logits_to_keep=logits_to_keep,
  915. **kwargs,
  916. )
  917. # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
  918. # Otherwise we need pixel values to be passed to model
  919. if cache_position[0] == 0:
  920. model_inputs["pixel_values"] = pixel_values
  921. return model_inputs
  922. def decode_image_tokens(self, image_tokens: torch.Tensor):
  923. """
  924. Decodes generated image tokens from language model to continuous pixel values
  925. with VQGAN module via upsampling.
  926. Args:
  927. image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
  928. The tensors corresponding to the input images.
  929. """
  930. decoded_image = self.model.vqmodel.decode(image_tokens)
  931. decoded_image = decoded_image.permute(0, 2, 3, 1)
  932. return decoded_image
  933. @torch.no_grad
  934. def generate(
  935. self,
  936. inputs: Optional[torch.Tensor] = None,
  937. attention_mask: Optional[torch.LongTensor] = None,
  938. logits_processor: Optional[LogitsProcessorList] = None,
  939. **kwargs,
  940. ):
  941. # 1. Handle generation config and model kwargs
  942. generation_config = kwargs.pop("generation_config", self.generation_config)
  943. generation_config = copy.deepcopy(generation_config)
  944. # Default to "text" generation if mode isn't provided
  945. generation_mode = kwargs.pop("generation_mode", "text")
  946. if generation_mode == "text":
  947. # Set guidance_scale=None to prevent running UnbatchedCFG processor.
  948. return super().generate(
  949. inputs=inputs,
  950. attention_mask=attention_mask,
  951. generation_config=generation_config,
  952. guidance_scale=None,
  953. **kwargs,
  954. )
  955. model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
  956. # Validate generation mode
  957. if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
  958. raise ValueError(
  959. "Got incompatible mode for Image Generation, should be one of greedy or sampling. "
  960. "Ensure that beam search is de-activated by setting `num_beams=1`."
  961. )
  962. # Validate the configuration and model kwargs
  963. generation_config.validate()
  964. self._validate_model_kwargs(model_kwargs.copy())
  965. # 2. Initialize logit processors
  966. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
  967. # Set `use_cache=True` as we will be using input embeds for generation.
  968. model_kwargs["use_cache"] = True
  969. if generation_config.guidance_scale is None:
  970. logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.")
  971. generation_config.guidance_scale = 5
  972. model_kwargs["guidance_scale"] = generation_config.guidance_scale
  973. # 3. Prepare model inputs
  974. input_ids, model_input_name, model_kwargs = self._prepare_model_inputs(
  975. inputs, generation_config.bos_token_id, model_kwargs
  976. )
  977. dtype, device = input_ids.dtype, input_ids.device
  978. if len(input_ids.shape) != 2:
  979. raise ValueError(
  980. f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}"
  981. "Passing `inputs embeds` is not supported currently."
  982. )
  983. # Prepare special tokens which will be used generate internally.
  984. kwargs_has_attention_mask = attention_mask is not None
  985. self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device)
  986. # 4. Add CFG processor along with user passed logit processor.
  987. if generation_config.guidance_scale and generation_config.guidance_scale > 1:
  988. logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale))
  989. generation_config.guidance_scale = None # Reset to prevent processor duplication.
  990. # 5. Prepare logits processor
  991. logits_processor = self._get_logits_processor(
  992. generation_config=generation_config,
  993. input_ids_seq_length=input_ids.shape[1],
  994. encoder_input_ids=input_ids,
  995. prefix_allowed_tokens_fn=None,
  996. logits_processor=logits_processor,
  997. device=device,
  998. )
  999. # 6. Expand inputs for multiple image generations per prompt.
  1000. input_ids, model_kwargs = self._expand_inputs_for_generation(
  1001. input_ids=input_ids,
  1002. attention_mask=attention_mask,
  1003. expand_size=generation_config.num_return_sequences,
  1004. **model_kwargs,
  1005. )
  1006. # 7. Prepare input and model caches
  1007. num_image_tokens = self.model.vision_model.config.num_image_tokens
  1008. batch_size, seq_len = input_ids.shape
  1009. input_tokens = input_ids.repeat(2, 1) # Double batch size for conditional/unconditional logits
  1010. attention_mask = model_kwargs.pop("attention_mask", None)
  1011. attention_mask = attention_mask.repeat(2, 1)
  1012. model_kwargs["attention_mask"] = attention_mask
  1013. # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits.
  1014. mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & (
  1015. input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"]
  1016. )
  1017. input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id)
  1018. inputs_embeds = self.get_input_embeddings()(input_tokens)
  1019. model_kwargs = self._get_initial_cache_position(seq_len, device, model_kwargs)
  1020. if model_kwargs.get("past_key_values", None) is None:
  1021. # Prepare cache if not provided.
  1022. model_kwargs["past_key_values"] = self._get_cache(
  1023. cache_implementation=generation_config.cache_implementation or "static",
  1024. # batch_size should account for both conditional/unconditional input; hence multiplied by 2.
  1025. batch_size=batch_size * 2,
  1026. # we should have at least a cache len of seq_len + num_image_tokens.
  1027. max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len),
  1028. model_kwargs=model_kwargs,
  1029. )
  1030. # Placeholder for generated tokens.
  1031. generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device)
  1032. # 8. init attention / hidden states / scores tuples
  1033. output_attentions = generation_config.output_attentions
  1034. output_hidden_states = generation_config.output_hidden_states
  1035. output_scores = generation_config.output_scores
  1036. output_logits = generation_config.output_logits
  1037. return_dict_in_generate = generation_config.return_dict_in_generate
  1038. raw_scores = () if (return_dict_in_generate and output_scores) else None
  1039. raw_logits = () if (return_dict_in_generate and output_logits) else None
  1040. decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
  1041. decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
  1042. for i in range(num_image_tokens):
  1043. model_inputs = self.prepare_inputs_for_generation(
  1044. inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs
  1045. )
  1046. model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device)
  1047. model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device)
  1048. outputs = self.model.language_model(
  1049. **model_inputs,
  1050. output_attentions=output_attentions,
  1051. output_hidden_states=output_hidden_states,
  1052. )
  1053. # Update model_kwargs like cache_position for next generation.
  1054. model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
  1055. hidden_state = outputs.last_hidden_state[:, -1, :].clone()
  1056. # Generate scores using the generation head (Not using above defined LM Head)
  1057. scores = self.model.generation_head(hidden_state)
  1058. next_token_scores = logits_processor(input_ids, scores)
  1059. # Sample next token.
  1060. if generation_config.do_sample:
  1061. probs = torch.softmax(next_token_scores, dim=-1)
  1062. next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
  1063. else:
  1064. next_token = torch.argmax(next_token_scores, dim=-1)
  1065. generated_tokens[:, i] = next_token
  1066. # Prepare embeddings for the next step.
  1067. next_token = torch.cat([next_token, next_token])
  1068. next_token = next_token.unsqueeze(-1)
  1069. inputs_embeds = self.prepare_embeddings_for_image_generation(next_token)
  1070. if return_dict_in_generate:
  1071. if output_scores:
  1072. raw_scores += (scores,)
  1073. if output_logits:
  1074. raw_logits += (hidden_state.float(),)
  1075. if output_attentions:
  1076. decoder_attentions += outputs.attentions
  1077. if output_hidden_states:
  1078. decoder_hidden_states += outputs.hidden_states
  1079. if return_dict_in_generate:
  1080. return GenerateDecoderOnlyOutput(
  1081. sequences=generated_tokens,
  1082. scores=scores,
  1083. logits=raw_logits,
  1084. attentions=decoder_attentions,
  1085. hidden_states=decoder_hidden_states,
  1086. past_key_values=outputs.past_key_values,
  1087. )
  1088. else:
  1089. return generated_tokens
  1090. class JanusImageProcessor(BlipImageProcessor):
  1091. r"""
  1092. Constructs a JANUS image processor.
  1093. Args:
  1094. do_resize (`bool`, *optional*, defaults to `True`):
  1095. Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
  1096. `do_resize` parameter in the `preprocess` method.
  1097. size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
  1098. Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
  1099. method.
  1100. min_size (`int`, *optional*, defaults to 14):
  1101. The minimum allowed size for the resized image. Ensures that neither the height nor width
  1102. falls below this value after resizing.
  1103. resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
  1104. Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
  1105. overridden by the `resample` parameter in the `preprocess` method.
  1106. do_rescale (`bool`, *optional*, defaults to `True`):
  1107. Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
  1108. `do_rescale` parameter in the `preprocess` method.
  1109. rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
  1110. Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
  1111. overridden by the `rescale_factor` parameter in the `preprocess` method.
  1112. do_normalize (`bool`, *optional*, defaults to `True`):
  1113. Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
  1114. method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
  1115. image_mean (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
  1116. Mean to use if normalizing the image. This is a float or list of floats the length of the number of
  1117. channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
  1118. overridden by the `image_mean` parameter in the `preprocess` method.
  1119. image_std (`float` or `list[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
  1120. Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
  1121. number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
  1122. Can be overridden by the `image_std` parameter in the `preprocess` method.
  1123. do_convert_rgb (`bool`, *optional*, defaults to `True`):
  1124. Whether to convert the image to RGB.
  1125. do_pad (`bool`, *optional*, defaults to `True`):
  1126. Whether to pad the image to square or not.
  1127. """
  1128. def __init__(
  1129. self,
  1130. do_resize: bool = True,
  1131. size: Optional[dict[str, int]] = None,
  1132. min_size: int = 14,
  1133. resample: PILImageResampling = PILImageResampling.BICUBIC,
  1134. do_rescale: bool = True,
  1135. rescale_factor: Union[int, float] = 1 / 255,
  1136. do_normalize: bool = True,
  1137. image_mean: Optional[Union[float, list[float]]] = None,
  1138. image_std: Optional[Union[float, list[float]]] = None,
  1139. do_convert_rgb: Optional[bool] = None,
  1140. do_pad: Optional[bool] = True,
  1141. **kwargs,
  1142. ):
  1143. super().__init__(**kwargs)
  1144. self.do_pad = do_pad
  1145. self.min_size = min_size
  1146. if image_mean is None:
  1147. self.background_color = (127, 127, 127)
  1148. else:
  1149. self.background_color = tuple(int(x * 255) for x in image_mean)
  1150. def pad_to_square(
  1151. self,
  1152. image: np.ndarray,
  1153. background_color: Union[int, tuple[int, int, int]] = 0,
  1154. data_format: Optional[Union[str, ChannelDimension]] = None,
  1155. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  1156. ) -> np.ndarray:
  1157. """
  1158. Pads an image to a square based on the longest edge.
  1159. Args:
  1160. image (`np.ndarray`):
  1161. The image to pad.
  1162. background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to 0):
  1163. The color to use for the padding. Can be an integer for single channel or a
  1164. tuple of integers representing for multi-channel images. If passed as integer
  1165. in multi-channel mode, it will default to `0` in subsequent channels.
  1166. data_format (`str` or `ChannelDimension`, *optional*):
  1167. The channel dimension format for the output image. Can be one of:
  1168. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  1169. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  1170. If unset, will use same as the input image.
  1171. input_data_format (`str` or `ChannelDimension`, *optional*):
  1172. The channel dimension format for the input image. Can be one of:
  1173. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  1174. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  1175. Returns:
  1176. `np.ndarray`: The padded image.
  1177. """
  1178. height, width = get_image_size(image, input_data_format)
  1179. num_channels = image.shape[0] if input_data_format == ChannelDimension.FIRST else image.shape[-1]
  1180. if height == width:
  1181. image = (
  1182. to_channel_dimension_format(image, data_format, input_data_format)
  1183. if data_format is not None
  1184. else image
  1185. )
  1186. return image
  1187. max_dim = max(height, width)
  1188. # Ensure background_color is the correct shape
  1189. if isinstance(background_color, int):
  1190. background_color = [background_color]
  1191. elif len(background_color) != num_channels:
  1192. raise ValueError(
  1193. f"background_color must have no more than {num_channels} elements to match the number of channels"
  1194. )
  1195. if input_data_format == ChannelDimension.FIRST:
  1196. result = np.zeros((num_channels, max_dim, max_dim), dtype=image.dtype)
  1197. for i, color in enumerate(background_color):
  1198. result[i, :, :] = color
  1199. if width > height:
  1200. start = (max_dim - height) // 2
  1201. result[:, start : start + height, :] = image
  1202. else:
  1203. start = (max_dim - width) // 2
  1204. result[:, :, start : start + width] = image
  1205. else:
  1206. result = np.zeros((max_dim, max_dim, num_channels), dtype=image.dtype)
  1207. for i, color in enumerate(background_color):
  1208. result[:, :, i] = color
  1209. if width > height:
  1210. start = (max_dim - height) // 2
  1211. result[start : start + height, :, :] = image
  1212. else:
  1213. start = (max_dim - width) // 2
  1214. result[:, start : start + width, :] = image
  1215. return result
  def resize(
      self,
      image: np.ndarray,
      size: Union[dict[str, int], int],
      resample: PILImageResampling = PILImageResampling.BICUBIC,
      data_format: Optional[Union[str, ChannelDimension]] = None,
      input_data_format: Optional[Union[str, ChannelDimension]] = None,
      **kwargs,
  ) -> np.ndarray:
      """
      Resize an image so that its longest edge equals the (square) target size, preserving the aspect ratio.

      The shorter edge is scaled by the same factor. Both output edges are clamped from below to
      `self.min_size`. The result is generally not square.

      Args:
          image (`np.ndarray`):
              Image to resize.
          size (`dict[str, int]` or `int`):
              The target size. If a dictionary, it must have the keys `"height"` and `"width"` with equal values.
              An integer `s` is interpreted as `{"height": s, "width": s}`.
          resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
              `PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
          data_format (`ChannelDimension` or `str`, *optional*):
              The channel dimension format for the output image. If unset, the channel dimension format of the input
              image is used. Can be one of:
              - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
              - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
              - `None`: will be inferred from input
          input_data_format (`ChannelDimension` or `str`, *optional*):
              The channel dimension format for the input image. If unset, the channel dimension format is inferred
              from the input image. Can be one of:
              - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
              - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
              - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
          **kwargs:
              Forwarded unchanged to the `resize` image transform.

      Returns:
          `np.ndarray`: The resized image.

      Raises:
          ValueError: If `size` describes a non-square target (height != width).
      """
      if input_data_format is None:
          input_data_format = infer_channel_dimension_format(image)
      height, width = get_image_size(image, input_data_format)
      longest_edge = max(height, width)

      size = get_size_dict(size, default_to_square=True)
      if size["height"] != size["width"]:
          raise ValueError(
              f"Output height and width must be the same. Got height={size['height']} and width={size['width']}"
          )
      target_size = size["height"]

      # Largest side becomes `target_size` and the other side is scaled according to the aspect ratio.
      # Exact integer floor division is used here. The float form `int(edge * (target_size / longest_edge))`
      # can round the longest edge down to `target_size - 1`, which would make a later square pad too small.
      output_size_nonpadded = [
          max(int(height * target_size // longest_edge), self.min_size),
          max(int(width * target_size // longest_edge), self.min_size),
      ]
      return resize(
          image,
          size=output_size_nonpadded,
          resample=resample,
          data_format=data_format,
          input_data_format=input_data_format,
          **kwargs,
      )
  @filter_out_non_signature_kwargs()
  def preprocess(
      self,
      images: ImageInput,
      do_resize: Optional[bool] = None,
      size: Optional[Union[dict[str, int], int]] = None,
      resample: Optional[PILImageResampling] = None,
      do_rescale: Optional[bool] = None,
      rescale_factor: Optional[float] = None,
      do_normalize: Optional[bool] = None,
      image_mean: Optional[Union[float, list[float]]] = None,
      image_std: Optional[Union[float, list[float]]] = None,
      return_tensors: Optional[Union[str, TensorType]] = None,
      do_convert_rgb: Optional[bool] = None,
      background_color: Optional[Union[int, tuple[int, int, int]]] = None,
      do_pad: Optional[bool] = None,
      data_format: ChannelDimension = ChannelDimension.FIRST,
      input_data_format: Optional[Union[str, ChannelDimension]] = None,
  ) -> BatchFeature:
      """
      Preprocess an image or batch of images.

      Steps, each gated by its flag:
      1. RGB conversion.
      2. Conversion to numpy.
      3. Resize (longest edge to the target size).
      4. Pad to square.
      5. Rescale.
      6. Normalize.
      7. Channel-format conversion.

      Every argument left as `None` falls back to the processor attribute of the same name.

      Args:
          images (`ImageInput`):
              Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
              passing in images with pixel values between 0 and 1, set `do_rescale=False`.
          do_resize (`bool`, *optional*, defaults to `self.do_resize`):
              Whether to resize the image.
          size (`dict[str, int]` or `int`, *optional*, defaults to `self.size`):
              Target size for resizing. Must describe a square: `{"height": s, "width": s}` or an integer `s`.
              The longest edge of each image is resized to `s` while preserving the aspect ratio. Combine with
              `do_pad=True` to obtain `s x s` images.
          resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
              Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
          do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
              Whether to rescale the image values between [0 - 1].
          rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
              Rescale factor to rescale the image by if `do_rescale` is set to `True`.
          do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
              Whether to normalize the image.
          image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`):
              Image mean to normalize the image by if `do_normalize` is set to `True`.
          image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`):
              Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
          do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
              Whether to convert the image to RGB.
          background_color (`int` or `tuple[int, int, int]`, *optional*, defaults to `self.background_color`):
              The background color to use for the padding.
          do_pad (`bool`, *optional*, defaults to `self.do_pad`):
              Whether to pad the image to square or not.
          return_tensors (`str` or `TensorType`, *optional*):
              The type of tensors to return. Can be one of:
              - Unset: Return a list of `np.ndarray`.
              - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
              - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
              - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
              - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
          data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
              The channel dimension format for the output image. Can be one of:
              - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
              - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
              - `None`: Keep the channel dimension format of the input image.
          input_data_format (`ChannelDimension` or `str`, *optional*):
              The channel dimension format for the input image. If unset, the channel dimension format is inferred
              from the input image. Can be one of:
              - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
              - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
              - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

      Returns:
          `BatchFeature`: Holds the processed images under `"pixel_values"`, converted to `return_tensors` if given.
      """
      do_resize = do_resize if do_resize is not None else self.do_resize
      resample = resample if resample is not None else self.resample
      do_rescale = do_rescale if do_rescale is not None else self.do_rescale
      rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
      do_normalize = do_normalize if do_normalize is not None else self.do_normalize
      image_mean = image_mean if image_mean is not None else self.image_mean
      image_std = image_std if image_std is not None else self.image_std
      do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
      do_pad = do_pad if do_pad is not None else self.do_pad
      background_color = background_color if background_color is not None else self.background_color
      size = size if size is not None else self.size
      # `self.resize` normalizes `size` with `default_to_square=True` and indexes `size["height"]` and
      # `size["width"]`. Use the same convention here so an integer size reaches it as a square dict.
      size = get_size_dict(size, default_to_square=True)

      images = self.fetch_images(images)
      images = make_flat_list_of_images(images)

      if not valid_images(images):
          raise ValueError(
              "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
              "torch.Tensor, tf.Tensor or jax.ndarray."
          )

      validate_preprocess_arguments(
          do_rescale=do_rescale,
          rescale_factor=rescale_factor,
          do_normalize=do_normalize,
          image_mean=image_mean,
          image_std=image_std,
          do_resize=do_resize,
          size=size,
          resample=resample,
      )

      # PIL RGBA images are converted to RGB
      if do_convert_rgb:
          images = [convert_to_rgb(image) for image in images]

      # All transformations expect numpy arrays.
      images = [to_numpy_array(image) for image in images]

      if do_rescale and is_scaled_image(images[0]):
          logger.warning_once(
              "It looks like you are trying to rescale already rescaled images. If the input"
              " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
          )

      if input_data_format is None:
          # We assume that all images have the same channel dimension format.
          input_data_format = infer_channel_dimension_format(images[0])

      if do_resize:
          images = [
              self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
              for image in images
          ]

      if do_pad:
          # Expand and pad the images to obtain a square image of dimensions `size x size`
          images = [
              self.pad_to_square(
                  image=image,
                  background_color=background_color,
                  input_data_format=input_data_format,
              )
              for image in images
          ]

      if do_rescale:
          images = [
              self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
              for image in images
          ]

      if do_normalize:
          images = [
              self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
              for image in images
          ]

      # `data_format=None` is documented as "keep the input format". Passing None to
      # `to_channel_dimension_format` is not handled, so skip the conversion (as `pad_to_square` does).
      if data_format is not None:
          images = [
              to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
              for image in images
          ]

      return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
  def postprocess(
      self,
      images: ImageInput,
      do_rescale: Optional[bool] = None,
      rescale_factor: Optional[float] = None,
      do_normalize: Optional[bool] = None,
      image_mean: Optional[list[float]] = None,
      image_std: Optional[list[float]] = None,
      input_data_format: Optional[str] = None,
      return_tensors: Optional[str] = None,
  ):
      """
      Reverse the preprocessing transformations on decoded image tokens.

      Args:
          images (`ImageInput`):
              Images to post-process. They are flattened into a list first.
          do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
              Whether to rescale the pixel values and clip/cast them to `uint8` in [0, 255].
          rescale_factor (`float`, *optional*, defaults to `1.0 / self.rescale_factor`):
              Factor applied when rescaling. The default is the reciprocal of the preprocessing factor.
          do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
              Whether to undo the normalization via `self.unnormalize`.
          image_mean (`list[float]`, *optional*, defaults to `self.image_mean`):
              Mean used to undo the normalization.
          image_std (`list[float]`, *optional*, defaults to `self.image_std`):
              Standard deviation used to undo the normalization.
          input_data_format (`str`, *optional*):
              Channel dimension format of the inputs. Inferred from the first image when unset.
          return_tensors (`str`, *optional*):
              Tensor type for the returned `BatchFeature`. The special value `"PIL.Image.Image"` converts each
              image to a PIL image, but only when both `do_normalize` and `do_rescale` are enabled. In that case
              no tensor conversion is applied.

      Returns:
          If the inputs are already PIL images, they are returned as-is: a single image when only one was given,
          otherwise the list. Otherwise returns a `BatchFeature` whose `"pixel_values"` entry holds the processed
          images.
      """
      if do_rescale is None:
          do_rescale = self.do_rescale
      if rescale_factor is None:
          # Invert the factor that was applied during preprocessing.
          rescale_factor = 1.0 / self.rescale_factor
      if do_normalize is None:
          do_normalize = self.do_normalize
      if image_mean is None:
          image_mean = self.image_mean
      if image_std is None:
          image_std = self.image_std

      images = make_flat_list_of_images(images)  # always work on a flat list
      first_image = images[0]

      # PIL inputs need no post-processing; hand them straight back.
      if isinstance(first_image, PIL.Image.Image):
          return images[0] if len(images) == 1 else images

      if input_data_format is None:
          input_data_format = infer_channel_dimension_format(first_image)

      wants_pil = return_tensors == "PIL.Image.Image"
      pixel_values = []
      for raw_image in images:
          pixels = to_numpy_array(raw_image)
          if do_normalize:
              pixels = self.unnormalize(
                  image=pixels, image_mean=image_mean, image_std=image_std, input_data_format=input_data_format
              )
          if do_rescale:
              pixels = self.rescale(pixels, scale=rescale_factor, input_data_format=input_data_format)
              pixels = pixels.clip(0, 255).astype(np.uint8)
          if wants_pil and do_normalize and do_rescale:
              # PIL expects channels-last layout.
              pixels = to_channel_dimension_format(pixels, ChannelDimension.LAST, input_channel_dim=input_data_format)
              pixels = PIL.Image.fromarray(pixels)
          pixel_values.append(pixels)

      tensor_type = None if wants_pil else return_tensors
      return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=tensor_type)
  def unnormalize(
      self,
      image: np.ndarray,
      image_mean: Union[float, Iterable[float]],
      image_std: Union[float, Iterable[float]],
      input_data_format: Optional[Union[str, ChannelDimension]] = None,
  ) -> np.ndarray:
      """
      Invert a per-channel normalization: `image = (image * image_std) + image_mean`.

      Implemented by calling `self.normalize` with mean `-mean / std` and std `1 / std`, because
      `(x - (-mean / std)) / (1 / std) == x * std + mean`.

      Args:
          image (`np.ndarray`):
              The image to unnormalize. Mean and std are forced to 3 values, so the image is expected to have
              3 channels.
          image_mean (`float` or `Iterable[float]`):
              The mean used for unnormalization. A scalar is repeated for all 3 channels. Any iterable, including
              one-shot iterables such as generators, must yield exactly 3 values.
          image_std (`float` or `Iterable[float]`):
              The standard deviation used for unnormalization, with the same rules as `image_mean`.
          input_data_format (`ChannelDimension` or `str`, *optional*):
              The channel dimension format for the input image, forwarded to `self.normalize`. Can be one of:
              - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
              - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
              - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.

      Returns:
          `np.ndarray`: The unnormalized image.

      Raises:
          ValueError: If `image_mean` or `image_std` is an iterable that does not have exactly 3 values.
      """
      num_channels = 3

      def _per_channel(value, name):
          # Expand a scalar, or validate an iterable, into a tuple of `num_channels` values.
          if isinstance(value, Iterable):
              # Materialize first: the signature accepts any Iterable, and `len()` fails on generators.
              values = tuple(value)
              if len(values) != num_channels:
                  raise ValueError(f"{name} must have {num_channels} elements if it is an iterable, got {len(values)}")
              return values
          return (value,) * num_channels

      image_mean = _per_channel(image_mean, "mean")
      image_std = _per_channel(image_std, "std")

      rev_image_mean = tuple(-mean / std for mean, std in zip(image_mean, image_std))
      rev_image_std = tuple(1 / std for std in image_std)
      return self.normalize(
          image=image, mean=rev_image_mean, std=rev_image_std, input_data_format=input_data_format
      )
  # Public names exported by `from <module> import *`; order is significant only for readability.
  __all__ = [
      "JanusImageProcessor",
      "JanusPreTrainedModel",
      "JanusForConditionalGeneration",
      "JanusModel",
      "JanusVQVAE",
      "JanusVisionModel",
      "JanusVQVAEConfig",
      "JanusVisionConfig",
      "JanusConfig",
  ]