modular_aria.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610
  1. # coding=utf-8
  2. # Copyright 2024 The Rhymes-AI Teams Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from collections.abc import Iterable
  16. from typing import Optional, Union
  17. import numpy as np
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache
  22. from ...configuration_utils import PretrainedConfig
  23. from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_patch_output_size, select_best_resolution
  24. from ...image_transforms import PaddingMode, convert_to_rgb, pad, resize, to_channel_dimension_format
  25. from ...image_utils import (
  26. ChannelDimension,
  27. ImageInput,
  28. PILImageResampling,
  29. get_image_size,
  30. infer_channel_dimension_format,
  31. is_scaled_image,
  32. make_flat_list_of_images,
  33. to_numpy_array,
  34. valid_images,
  35. validate_preprocess_arguments,
  36. )
  37. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  38. from ...modeling_utils import PreTrainedModel
  39. from ...processing_utils import MultiModalData, ProcessingKwargs, ProcessorMixin, Unpack
  40. from ...tokenization_utils import PreTokenizedInput, TextInput
  41. from ...utils import TensorType, TransformersKwargs, auto_docstring, can_return_tuple, logging
  42. from ..auto import CONFIG_MAPPING, AutoConfig, AutoTokenizer
  43. from ..llama.configuration_llama import LlamaConfig
  44. from ..llama.modeling_llama import (
  45. LlamaAttention,
  46. LlamaDecoderLayer,
  47. LlamaForCausalLM,
  48. LlamaMLP,
  49. LlamaModel,
  50. LlamaPreTrainedModel,
  51. LlamaRMSNorm,
  52. )
  53. from ..llava.modeling_llava import (
  54. LlavaCausalLMOutputWithPast,
  55. LlavaForConditionalGeneration,
  56. LlavaModel,
  57. LlavaModelOutputWithPast,
  58. )
  59. from ..llava_next.image_processing_llava_next import divide_to_patches
  60. logger = logging.get_logger(__name__)
  61. def sequential_experts_gemm(token_states, expert_weights, tokens_per_expert):
  62. """
  63. Compute the matrix multiplication (GEMM) for each expert sequentially. This approach is computationally inefficient, especially when dealing with a large number of experts.
  64. Args:
  65. token_states (torch.Tensor): Input tensor of shape (num_tokens, in_features).
  66. expert_weights (torch.Tensor): Weight tensor of shape (num_experts, in_features, out_features).
  67. tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
  68. Returns:
  69. torch.Tensor: Output tensor of shape (num_tokens, out_features).
  70. """
  71. num_tokens = token_states.shape[0]
  72. out_features = expert_weights.shape[-1]
  73. output = torch.zeros(num_tokens, out_features, dtype=token_states.dtype, device=token_states.device)
  74. cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
  75. # Insert zero at the beginning for offset index's convenience
  76. zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
  77. cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
  78. for expert_num in range(expert_weights.shape[0]):
  79. start = cumsum_num_tokens[expert_num]
  80. end = cumsum_num_tokens[expert_num + 1]
  81. tokens = token_states[start:end]
  82. out = torch.matmul(tokens, expert_weights[expert_num])
  83. output[start:end] = out
  84. return output
  85. class AriaTextConfig(LlamaConfig):
  86. r"""
  87. This class handles the configuration for the text component of the Aria model.
  88. Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria
  89. [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
  90. This class extends the LlamaConfig to include additional parameters specific to the Mixture of Experts (MoE) architecture.
  91. Args:
  92. vocab_size (`int`, *optional*, defaults to 32000):
  93. Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
  94. `inputs_ids` passed when calling [`LlamaModel`]
  95. hidden_size (`int`, *optional*, defaults to 4096):
  96. Dimension of the hidden representations.
  97. intermediate_size (`int`, *optional*, defaults to 4096):
  98. The size of the MLP representations.
  99. num_hidden_layers (`int`, *optional*, defaults to 32):
  100. Number of hidden layers in the Transformer decoder.
  101. num_attention_heads (`int`, *optional*, defaults to 32):
  102. Number of attention heads for each attention layer in the Transformer decoder.
  103. num_key_value_heads (`int`, *optional*):
  104. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  105. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  106. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  107. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  108. by meanpooling all the original heads within that group. For more details, check out [this
  109. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  110. `num_attention_heads`.
  111. hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
  112. The non-linear activation function (function or string) in the decoder.
  113. max_position_embeddings (`int`, *optional*, defaults to 2048):
  114. The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
  115. Llama 2 up to 4096, CodeLlama up to 16384.
  116. initializer_range (`float`, *optional*, defaults to 0.02):
  117. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  118. rms_norm_eps (`float`, *optional*, defaults to 1e-06):
  119. The epsilon used by the rms normalization layers.
  120. use_cache (`bool`, *optional*, defaults to `True`):
  121. Whether or not the model should return the last key/values attentions (not used by all models). Only
  122. relevant if `config.is_decoder=True`.
  123. pad_token_id (`int`, *optional*, defaults to 2):
  124. Padding token id.
  125. bos_token_id (`int`, *optional*, defaults to 1):
  126. Beginning of stream token id.
  127. eos_token_id (`int`, *optional*, defaults to 2):
  128. End of stream token id.
  129. pretraining_tp (`int`, *optional*, defaults to 1):
  130. Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
  131. document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to
  132. understand more about it. This value is necessary to ensure exact reproducibility of the pretraining
  133. results. Please refer to [this issue](https://github.com/pytorch/pytorch/issues/76232).
  134. tie_word_embeddings (`bool`, *optional*, defaults to `False`):
  135. Whether to tie weight embeddings
  136. rope_theta (`float`, *optional*, defaults to 10000.0):
  137. The base period of the RoPE embeddings.
  138. rope_scaling (`Dict`, *optional*):
  139. Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
  140. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  141. accordingly.
  142. Expected contents:
  143. `rope_type` (`str`):
  144. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  145. 'llama3'], with 'default' being the original RoPE implementation.
  146. `factor` (`float`, *optional*):
  147. Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
  148. most scaling types, a `factor` of x will enable the model to handle sequences of length x *
  149. original maximum pre-trained length.
  150. `original_max_position_embeddings` (`int`, *optional*):
  151. Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
  152. pretraining.
  153. `attention_factor` (`float`, *optional*):
  154. Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
  155. computation. If unspecified, it defaults to value recommended by the implementation, using the
  156. `factor` field to infer the suggested value.
  157. `beta_fast` (`float`, *optional*):
  158. Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
  159. ramp function. If unspecified, it defaults to 32.
  160. `beta_slow` (`float`, *optional*):
  161. Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
  162. ramp function. If unspecified, it defaults to 1.
  163. `short_factor` (`list[float]`, *optional*):
  164. Only used with 'longrope'. The scaling factor to be applied to short contexts (<
  165. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  166. size divided by the number of attention heads divided by 2
  167. `long_factor` (`list[float]`, *optional*):
  168. Only used with 'longrope'. The scaling factor to be applied to long contexts (<
  169. `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
  170. size divided by the number of attention heads divided by 2
  171. `low_freq_factor` (`float`, *optional*):
  172. Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
  173. `high_freq_factor` (`float`, *optional*):
  174. Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
  175. attention_bias (`bool`, *optional*, defaults to `False`):
  176. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  177. attention_dropout (`float`, *optional*, defaults to 0.0):
  178. The dropout ratio for the attention probabilities.
  179. mlp_bias (`bool`, *optional*, defaults to `False`):
  180. Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers.
  181. head_dim (`int`, *optional*):
  182. The attention head dimension. If None, it will default to hidden_size // num_heads
  183. moe_num_experts (`int`, *optional*, defaults to 8):
  184. The number of experts in the MoE layer.
  185. moe_topk (`int`, *optional*, defaults to 2):
  186. The number of top experts to route to for each token.
  187. moe_num_shared_experts (`int`, *optional*, defaults to 2):
  188. The number of shared experts.
  189. """
  190. model_type = "aria_text"
  191. base_config_key = "text_config"
  192. def __init__(
  193. self,
  194. intermediate_size: int = 4096,
  195. moe_num_experts: int = 8,
  196. moe_topk: int = 2,
  197. moe_num_shared_experts: int = 2,
  198. pad_token_id=2,
  199. **super_kwargs,
  200. ):
  201. super().__init__(pad_token_id=pad_token_id, **super_kwargs)
  202. self.intermediate_size = intermediate_size
  203. self.moe_num_experts = moe_num_experts
  204. self.moe_topk = moe_topk
  205. self.moe_num_shared_experts = moe_num_shared_experts
  206. class AriaConfig(PretrainedConfig):
  207. r"""
  208. This class handles the configuration for both vision and text components of the Aria model,
  209. as well as additional parameters for image token handling and projector mapping.
  210. Instantiating a configuration with the defaults will yield a similar configuration to that of the model of the Aria
  211. [rhymes-ai/Aria](https://huggingface.co/rhymes-ai/Aria) architecture.
  212. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  213. documentation from [`PretrainedConfig`] for more information.
  214. Args:
  215. vision_config (`AriaVisionConfig` or `dict`, *optional*):
  216. Configuration for the vision component.
  217. vision_feature_layer (`int`, *optional*, defaults to -1):
  218. The index of the layer to select the vision feature.
  219. text_config (`AriaTextConfig` or `dict`, *optional*):
  220. Configuration for the text component.
  221. projector_patch_to_query_dict (`dict`, *optional*):
  222. Mapping of patch sizes to query dimensions.
  223. image_token_index (`int`, *optional*, defaults to 9):
  224. Index used to represent image tokens.
  225. initializer_range (`float`, *optional*, defaults to 0.02):
  226. The standard deviation of the truncated normal initializer for initializing all weight matrices.
  227. Attributes:
  228. model_type (`str`):
  229. Type of the model, set to `"aria"`.
  230. image_token_index (`int`):
  231. Index used to represent image tokens.
  232. projector_patch_to_query_dict (`dict`):
  233. Mapping of patch sizes to query dimensions.
  234. vision_config (`AriaVisionConfig`):
  235. Configuration for the vision component.
  236. text_config (`AriaTextConfig`):
  237. Configuration for the text component.
  238. """
  239. model_type = "aria"
  240. attribute_map = {
  241. "image_token_id": "image_token_index",
  242. }
  243. sub_configs = {"text_config": AriaTextConfig, "vision_config": AutoConfig}
  244. def __init__(
  245. self,
  246. vision_config=None,
  247. vision_feature_layer: int = -1,
  248. text_config: AriaTextConfig = None,
  249. projector_patch_to_query_dict: Optional[dict] = None,
  250. image_token_index: int = 9,
  251. initializer_range: float = 0.02,
  252. **kwargs,
  253. ):
  254. self.image_token_index = image_token_index
  255. # Convert the keys and values of projector_patch_to_query_dict to integers
  256. # This ensures consistency even if they were provided as strings
  257. if projector_patch_to_query_dict is None:
  258. projector_patch_to_query_dict = {
  259. 1225: 128,
  260. 4900: 256,
  261. }
  262. self.projector_patch_to_query_dict = {int(k): int(v) for k, v in projector_patch_to_query_dict.items()}
  263. self.max_value_projector_patch_to_query_dict = max(self.projector_patch_to_query_dict.values())
  264. self.vision_feature_layer = vision_feature_layer
  265. if isinstance(vision_config, dict):
  266. vision_config["model_type"] = "idefics3_vision"
  267. vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
  268. elif vision_config is None:
  269. vision_config = CONFIG_MAPPING["idefics3_vision"]()
  270. self.vision_config = vision_config
  271. self.initializer_range = initializer_range
  272. if isinstance(text_config, dict) and "model_type" in text_config:
  273. text_config = AriaTextConfig(**text_config)
  274. elif text_config is None:
  275. text_config = AriaTextConfig()
  276. self.text_config = text_config
  277. super().__init__(**kwargs)
  278. class AriaTextRMSNorm(LlamaRMSNorm):
  279. pass
  280. class AriaProjectorMLP(nn.Module):
  281. """
  282. Feed-Forward Network module for the Aria Projector.
  283. Args:
  284. in_features (`int`):
  285. Input embedding dimension.
  286. hidden_features (`int`):
  287. Hidden dimension of the feed-forward network.
  288. output_dim (`int`):
  289. Output dimension.
  290. """
  291. def __init__(self, in_features, hidden_features, output_dim):
  292. super().__init__()
  293. self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
  294. self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
  295. self.act = ACT2FN["gelu_new"]
  296. def forward(self, hidden_states):
  297. hidden_states = self.act(self.linear_in(hidden_states))
  298. hidden_states = self.linear_out(hidden_states)
  299. return hidden_states
  300. class AriaCrossAttention(nn.Module):
  301. """
  302. Aria Cross-Attention module.
  303. Args:
  304. config (`AriaConfig`):
  305. The configuration to use.
  306. """
  307. def __init__(self, config: AriaConfig, dropout_rate: float = 0):
  308. super().__init__()
  309. hidden_size = config.vision_config.hidden_size
  310. num_heads = config.vision_config.num_attention_heads
  311. self.num_heads = num_heads
  312. self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
  313. self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
  314. self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False)
  315. # Original code here: https://github.com/rhymes-ai/Aria/blob/719ff4e52b727443cba3793b0e27fe64e0244fe1/aria/model/projector.py#L48
  316. self.multihead_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
  317. self.linear = nn.Linear(hidden_size, hidden_size)
  318. self.dropout = nn.Dropout(dropout_rate)
  319. self.layer_norm = nn.LayerNorm(hidden_size)
  320. self.layer_norm_kv = nn.LayerNorm(hidden_size)
  321. def forward(self, key_value_states, hidden_states, attn_mask=None):
  322. """
  323. Forward pass of the AriaCrossAttention module.
  324. Args:
  325. key_value_states (`torch.Tensor`):
  326. Input tensor for key and value.
  327. hidden_states (`torch.Tensor`):
  328. Input tensor for query.
  329. attn_mask (`torch.Tensor`, *optional*, defaults to None):
  330. Attention mask.
  331. Returns:
  332. torch.Tensor:
  333. Output tensor after cross-attention.
  334. """
  335. query = self.q_proj(self.layer_norm(hidden_states))
  336. key_value_states = self.layer_norm_kv(key_value_states)
  337. key = self.k_proj(key_value_states)
  338. value = self.v_proj(key_value_states)
  339. attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask)
  340. attn_output = self.dropout(self.linear(attn_output))
  341. return attn_output
  342. class AriaProjector(nn.Module):
  343. """
  344. Aria Projector module.
  345. This module projects vision features into the language model's embedding space, enabling interaction between vision and language components.
  346. Args:
  347. config (`AriaConfig`):
  348. Configuration object for the model.
  349. """
  350. def __init__(
  351. self,
  352. config: AriaConfig,
  353. ):
  354. super().__init__()
  355. self.patch_to_query_dict = config.projector_patch_to_query_dict
  356. self.in_features = config.vision_config.hidden_size
  357. self.num_heads = config.vision_config.num_attention_heads
  358. self.kv_dim = config.vision_config.hidden_size
  359. self.hidden_features = config.text_config.hidden_size
  360. self.output_dim = config.text_config.hidden_size
  361. self.query = nn.Parameter(torch.zeros(config.max_value_projector_patch_to_query_dict, self.in_features))
  362. self.cross_attn = AriaCrossAttention(config)
  363. self.layer_norm = nn.LayerNorm(self.in_features)
  364. self.feed_forward = AriaProjectorMLP(self.in_features, self.hidden_features, self.output_dim)
  365. def forward(self, key_value_states: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
  366. """
  367. Forward pass of the Projector module.
  368. Args:
  369. key_value_states (`torch.Tensor`):
  370. Input tensor of shape (batch_size, num_patches, kv_dim).
  371. attn_mask (`torch.Tensor`, *optional*, default is None):
  372. Attention mask.
  373. Returns:
  374. `torch.Tensor`: Output tensor of shape (batch_size, query_number, output_dim).
  375. """
  376. batch_size, num_patches = key_value_states.shape[0], key_value_states.shape[1]
  377. if num_patches not in self.patch_to_query_dict:
  378. raise KeyError(
  379. f"Number of patches {num_patches} not found in patch_to_query_dict amongst possible values {self.patch_to_query_dict.keys()}."
  380. )
  381. query_num = self.patch_to_query_dict[num_patches]
  382. queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
  383. if attn_mask is not None:
  384. attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
  385. attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
  386. attention_out = self.cross_attn(key_value_states, queries, attn_mask=attn_mask)
  387. out = self.feed_forward(self.layer_norm(attention_out))
  388. return out
  389. class AriaImageProcessor(BaseImageProcessor):
  390. """
  391. A vision processor for the Aria model that handles image preprocessing.
  392. Initialize the AriaImageProcessor.
  393. Args:
  394. image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
  395. Mean values for normalization.
  396. image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
  397. Standard deviation values for normalization.
  398. max_image_size (`int`, *optional*, defaults to 980):
  399. Maximum image size.
  400. min_image_size (`int`, *optional*, defaults to 336):
  401. Minimum image size.
  402. split_resolutions (`list`, *optional*, defaults to a list of optimal,resolutions as tuples):
  403. The optimal resolutions for splitting the image.
  404. split_image (`bool`, *optional*, defaults to `False`):
  405. Whether to split the image.
  406. do_convert_rgb (`bool`, *optional*, defaults to `True`):
  407. Whether to convert the image to RGB.
  408. do_rescale (`bool`, *optional*, defaults to `True`):
  409. Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
  410. the `preprocess` method.
  411. rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
  412. Scale factor to use if rescaling the image. Can be overridden by `rescale_factor` in the `preprocess`
  413. method.
  414. do_normalize (`bool`, *optional*, defaults to `True`):
  415. Whether to normalize the image.
  416. resample (PILImageResampling, *optional*, defaults to `BICUBIC`):
  417. The resampling filter to use if resizing the image.
  418. """
  419. model_input_names = ["pixel_values", "pixel_mask", "num_crops"]
  420. def __init__(
  421. self,
  422. image_mean: Optional[list[float]] = None,
  423. image_std: Optional[list[float]] = None,
  424. max_image_size: int = 980,
  425. min_image_size: int = 336,
  426. split_resolutions: Optional[list[tuple[int, int]]] = None,
  427. split_image: Optional[bool] = False,
  428. do_convert_rgb: Optional[bool] = True,
  429. do_rescale: bool = True,
  430. rescale_factor: Union[int, float] = 1 / 255,
  431. do_normalize: Optional[bool] = True,
  432. resample: PILImageResampling = PILImageResampling.BICUBIC,
  433. **kwargs,
  434. ):
  435. super().__init__(**kwargs)
  436. if image_mean is None:
  437. image_mean = [0.5, 0.5, 0.5]
  438. if image_std is None:
  439. image_std = [0.5, 0.5, 0.5]
  440. self.max_image_size = max_image_size
  441. self.min_image_size = min_image_size
  442. self.image_mean = image_mean
  443. self.image_std = image_std
  444. self.split_image = split_image
  445. if split_resolutions is None:
  446. split_resolutions = [(1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (2, 4), (2, 3), (2, 2), (2, 1), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1), (6, 1), (7, 1), (8, 1)] # fmt: skip
  447. split_resolutions = [(el[0] * 490, el[1] * 490) for el in split_resolutions]
  448. self.split_resolutions = split_resolutions
  449. self.do_convert_rgb = do_convert_rgb
  450. self.do_rescale = do_rescale
  451. self.rescale_factor = rescale_factor
  452. self.do_normalize = do_normalize
  453. self.resample = resample
  454. def preprocess(
  455. self,
  456. images: Union[ImageInput, list[ImageInput]],
  457. image_mean: Optional[Union[float, list[float]]] = None,
  458. image_std: Optional[Union[float, list[float]]] = None,
  459. max_image_size: Optional[int] = None,
  460. min_image_size: Optional[int] = None,
  461. split_image: Optional[bool] = None,
  462. do_convert_rgb: Optional[bool] = None,
  463. do_rescale: Optional[bool] = None,
  464. rescale_factor: Optional[float] = None,
  465. do_normalize: Optional[bool] = None,
  466. resample: Optional[PILImageResampling] = None,
  467. return_tensors: Optional[Union[str, TensorType]] = "pt",
  468. data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
  469. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  470. ):
  471. """
  472. Process a list of images.
  473. Args:
  474. images (ImageInput or list of ImageInput):
  475. The input image or a list of images.
  476. image_mean (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
  477. Mean values for normalization.
  478. image_std (`list`, *optional*, defaults to [0.5, 0.5, 0.5]):
  479. Standard deviation values for normalization.
  480. max_image_size (`int`, *optional*, defaults to `self.max_image_size` (980)):
  481. Maximum image size.
  482. min_image_size (`int`, *optional*, defaults to `self.min_image_size` (336)):
  483. Minimum image size.
  484. split_image (`bool`, *optional*, defaults to `self.split_image` (False)):
  485. Whether to split the image.
  486. do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb` (True)):
  487. Whether to convert the image to RGB.
  488. do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
  489. Whether to rescale the image.
  490. rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
  491. Rescale factor to rescale the image by if `do_rescale` is set to `True`.
  492. do_normalize (`bool`, *optional*, defaults to `self.do_normalize` (True)):
  493. Whether to normalize the image.
  494. resample (PILImageResampling, *optional*, defaults to `self.resample` (BICUBIC)):
  495. The resampling filter to use if resizing the image.
  496. return_tensors (`str` or `TensorType`, *optional*, defaults to "pt"):
  497. The type of tensor to return.
  498. data_format (`str` or `ChannelDimension`, *optional*):
  499. The channel dimension format for the output image. Can be one of:
  500. - `"channels_first"` or `ChannelDimension.FIRST`:
  501. image in (num_channels, height, width) format.
  502. - `"channels_last"` or `ChannelDimension.LAST`:
  503. image in (height, width, num_channels) format.
  504. If unset, will use same as the input image.
  505. input_data_format (`str` or `ChannelDimension`, *optional*):
  506. The channel dimension format for the input image. Can be one of:
  507. - `"channels_first"` or `ChannelDimension.FIRST`:
  508. image in (num_channels, height, width) format.
  509. - `"channels_last"` or `ChannelDimension.LAST`:
  510. image in (height, width, num_channels) format.
  511. If unset, will use the inferred format of the input image.
  512. Returns:
  513. BatchFeature:
  514. A BatchFeature object containing:
  515. - 'pixel_values':
  516. Tensor of processed image pixel values.
  517. - 'pixel_mask':
  518. Boolean pixel mask. This mask is a 2D tensor of shape (max_image_size, max_image_size) where:
  519. - True (1) values indicate pixels that belong to the original resized image.
  520. - False (0) values indicate pixels that are part of the padding.
  521. The mask helps distinguish between actual image content and padded areas in subsequent processing steps.
  522. - 'num_crops':
  523. The maximum number of crops across all images.
  524. """
  525. image_mean = image_mean if image_mean is not None else self.image_mean
  526. image_std = image_std if image_std is not None else self.image_std
  527. max_image_size = max_image_size if max_image_size is not None else self.max_image_size
  528. min_image_size = min_image_size if min_image_size is not None else self.min_image_size
  529. split_image = split_image if split_image is not None else self.split_image
  530. do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
  531. do_rescale = do_rescale if do_rescale is not None else self.do_rescale
  532. rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
  533. do_normalize = do_normalize if do_normalize is not None else self.do_normalize
  534. resample = resample if resample is not None else self.resample
  535. if max_image_size not in [490, 980]:
  536. raise ValueError("max_image_size must be either 490 or 980")
  537. images = self.fetch_images(images)
  538. images = make_flat_list_of_images(images)
  539. if not valid_images(images):
  540. raise ValueError(
  541. "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
  542. "torch.Tensor, tf.Tensor or jax.ndarray."
  543. )
  544. validate_preprocess_arguments(
  545. do_normalize=do_normalize,
  546. image_mean=image_mean,
  547. image_std=image_std,
  548. resample=resample,
  549. do_rescale=do_rescale,
  550. rescale_factor=rescale_factor,
  551. )
  552. if do_convert_rgb:
  553. images = [convert_to_rgb(image) for image in images]
  554. # All transformations expect numpy arrays.
  555. images = [to_numpy_array(image) for image in images]
  556. if do_rescale and is_scaled_image(images[0]):
  557. logger.warning_once(
  558. "It looks like you are trying to rescale already rescaled images. If the input"
  559. " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
  560. )
  561. if input_data_format is None:
  562. # We assume that all images have the same channel dimension format.
  563. input_data_format = infer_channel_dimension_format(images[0])
  564. pixel_values = []
  565. pixel_masks = []
  566. num_crops = None
  567. for image in images:
  568. if split_image:
  569. crop_images = self.get_image_patches(
  570. image,
  571. self.split_resolutions,
  572. max_image_size,
  573. resample,
  574. data_format=input_data_format,
  575. input_data_format=input_data_format,
  576. )
  577. else:
  578. crop_images = [image]
  579. if num_crops is None or len(crop_images) > num_crops:
  580. num_crops = len(crop_images)
  581. for crop_image in crop_images:
  582. # At this point the scale is the rescaling factor that would bring the image to max_size in its larger dimension
  583. h, w = get_image_size(crop_image)
  584. scale = max_image_size / max(h, w)
  585. if w >= h:
  586. new_size = (max(int(h * scale), min_image_size), max_image_size) # h, w
  587. else:
  588. new_size = (max_image_size, max(int(w * scale), min_image_size)) # h, w
  589. crop_image_resized = resize(
  590. crop_image,
  591. new_size,
  592. resample=resample,
  593. data_format=input_data_format,
  594. input_data_format=input_data_format,
  595. )
  596. padding_bottom, padding_right = max_image_size - new_size[0], max_image_size - new_size[1]
  597. crop_image_padded = pad(
  598. crop_image_resized,
  599. ((0, padding_bottom), (0, padding_right)),
  600. data_format=input_data_format,
  601. input_data_format=input_data_format,
  602. )
  603. # Create a pixel mask
  604. pixel_mask = np.zeros((max_image_size, max_image_size), dtype=bool)
  605. pixel_mask[: new_size[0], : new_size[1]] = 1
  606. pixel_masks.append(pixel_mask)
  607. if do_rescale:
  608. crop_image_padded = self.rescale(
  609. image=crop_image_padded, scale=rescale_factor, input_data_format=input_data_format
  610. )
  611. if do_normalize:
  612. crop_image_padded = self.normalize(
  613. crop_image_padded,
  614. self.image_mean,
  615. self.image_std,
  616. data_format=input_data_format,
  617. input_data_format=input_data_format,
  618. )
  619. crop_image_padded = (
  620. to_channel_dimension_format(crop_image_padded, data_format, input_data_format)
  621. if data_format is not None
  622. else crop_image_padded
  623. )
  624. pixel_values.append(crop_image_padded)
  625. return BatchFeature(
  626. data={
  627. "pixel_values": np.stack(pixel_values, axis=0),
  628. "pixel_mask": np.stack(pixel_masks, axis=0),
  629. "num_crops": num_crops,
  630. },
  631. tensor_type=return_tensors,
  632. )
  633. def _resize_for_patching(
  634. self, image: np.ndarray, target_resolution: tuple, resample, input_data_format: ChannelDimension
  635. ) -> np.ndarray:
  636. """
  637. Resizes an image to a target resolution while maintaining aspect ratio.
  638. Args:
  639. image (np.ndarray):
  640. The input image.
  641. target_resolution (tuple):
  642. The target resolution (height, width) of the image.
  643. resample (`PILImageResampling`):
  644. Resampling filter to use if resizing the image.
  645. input_data_format (`ChannelDimension` or `str`):
  646. The channel dimension format of the input image.
  647. Returns:
  648. np.ndarray: The resized and padded image.
  649. """
  650. new_height, new_width = get_patch_output_size(image, target_resolution, input_data_format)
  651. # Resize the image
  652. resized_image = resize(image, (new_height, new_width), resample=resample, input_data_format=input_data_format)
  653. return resized_image
  654. def _get_padding_size(self, original_resolution: tuple, target_resolution: tuple):
  655. original_height, original_width = original_resolution
  656. target_height, target_width = target_resolution
  657. paste_x, r_x = divmod(target_width - original_width, 2)
  658. paste_y, r_y = divmod(target_height - original_height, 2)
  659. return (paste_y, paste_y + r_y), (paste_x, paste_x + r_x)
  660. def _pad_for_patching(
  661. self, image: np.ndarray, target_resolution: tuple, input_data_format: ChannelDimension
  662. ) -> np.ndarray:
  663. """
  664. Pad an image to a target resolution while maintaining aspect ratio.
  665. """
  666. new_resolution = get_patch_output_size(image, target_resolution, input_data_format)
  667. padding = self._get_padding_size(new_resolution, target_resolution)
  668. padded_image = self.pad(image, padding=padding)
  669. return padded_image
  670. def pad(
  671. self,
  672. image: np.ndarray,
  673. padding: Union[int, tuple[int, int], Iterable[tuple[int, int]]],
  674. mode: PaddingMode = PaddingMode.CONSTANT,
  675. constant_values: Union[float, Iterable[float]] = 0.0,
  676. data_format: Optional[Union[str, ChannelDimension]] = None,
  677. input_data_format: Optional[Union[str, ChannelDimension]] = None,
  678. ) -> np.ndarray:
  679. """
  680. Pads the `image` with the specified `padding` and `mode`. Padding can be in the (`height`, `width`)
  681. dimension of in the (`num_patches`) dimension. In the second case an iterable if tuples is expected
  682. as input.
  683. Args:
  684. image (`np.ndarray`):
  685. The image to pad.
  686. padding (`int` or `tuple[int, int]` or `Iterable[tuple[int, int]]`):
  687. Padding to apply to the edges of the height, width axes. Can be one of three formats:
  688. - `((before_height, after_height), (before_width, after_width))` unique pad widths for each axis.
  689. - `((before, after),)` yields same before and after pad for height and width.
  690. - `(pad,)` or int is a shortcut for before = after = pad width for all axes.
  691. mode (`PaddingMode`):
  692. The padding mode to use. Can be one of:
  693. - `"constant"`: pads with a constant value.
  694. - `"reflect"`: pads with the reflection of the vector mirrored on the first and last values of the
  695. vector along each axis.
  696. - `"replicate"`: pads with the replication of the last value on the edge of the array along each axis.
  697. - `"symmetric"`: pads with the reflection of the vector mirrored along the edge of the array.
  698. constant_values (`float` or `Iterable[float]`, *optional*):
  699. The value to use for the padding if `mode` is `"constant"`.
  700. data_format (`str` or `ChannelDimension`, *optional*):
  701. The channel dimension format for the output image. Can be one of:
  702. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  703. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  704. If unset, will use same as the input image.
  705. input_data_format (`str` or `ChannelDimension`, *optional*):
  706. The channel dimension format for the input image. Can be one of:
  707. - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
  708. - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
  709. If unset, will use the inferred format of the input image.
  710. Returns:
  711. `np.ndarray`: The padded image.
  712. """
  713. # call the general `pad` if padding on `height/width`, otherwise it's the `num_patched` dim
  714. if isinstance(padding, int) or len(padding) != 4:
  715. return pad(image, padding, mode, constant_values, data_format, input_data_format)
  716. if input_data_format is None:
  717. input_data_format = infer_channel_dimension_format(image)
  718. padding_mode_mapping = {
  719. PaddingMode.CONSTANT: "constant",
  720. PaddingMode.REFLECT: "reflect",
  721. PaddingMode.REPLICATE: "edge",
  722. PaddingMode.SYMMETRIC: "symmetric",
  723. }
  724. image = np.pad(image, padding, mode=padding_mode_mapping[mode], constant_values=constant_values)
  725. image = (
  726. to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image
  727. )
  728. return image
  729. def get_image_patches(
  730. self,
  731. image: np.ndarray,
  732. grid_pinpoints: list[tuple[int, int]],
  733. patch_size: int,
  734. resample: PILImageResampling,
  735. data_format: ChannelDimension,
  736. input_data_format: ChannelDimension,
  737. ) -> list[np.ndarray]:
  738. """
  739. Process an image with variable resolutions by dividing it into patches.
  740. Args:
  741. image (`np.ndarray`):
  742. The input image to be processed.
  743. grid_pinpoints (list[tuple[int, int]]):
  744. A list of possible resolutions as tuples.
  745. patch_size (`int`):
  746. Size of the patches to divide the image into.
  747. resample (`PILImageResampling`):
  748. Resampling filter to use if resizing the image.
  749. data_format (`ChannelDimension` or `str`):
  750. The channel dimension format for the output image.
  751. input_data_format (`ChannelDimension` or `str`):
  752. The channel dimension format of the input image.
  753. Returns:
  754. `list[np.ndarray]`: A list of NumPy arrays containing the processed image patches.
  755. """
  756. if not isinstance(grid_pinpoints, list):
  757. raise TypeError("grid_pinpoints must be a list of possible resolutions.")
  758. possible_resolutions = grid_pinpoints
  759. image_size = get_image_size(image, channel_dim=input_data_format)
  760. best_resolution = select_best_resolution(image_size, possible_resolutions)
  761. resized_image = self._resize_for_patching(
  762. image, best_resolution, resample=resample, input_data_format=input_data_format
  763. )
  764. padded_image = self._pad_for_patching(resized_image, best_resolution, input_data_format=input_data_format)
  765. patches = divide_to_patches(padded_image, patch_size=patch_size, input_data_format=input_data_format)
  766. # make sure that all patches are in the input data format
  767. patches = [
  768. to_channel_dimension_format(patch, channel_dim=data_format, input_channel_dim=input_data_format)
  769. for patch in patches
  770. ]
  771. return patches
  772. def get_number_of_image_patches(self, height: int, width: int, images_kwargs=None):
  773. """
  774. A utility that returns number of image patches for a given image size.
  775. Args:
  776. height (`int`):
  777. Height of the input image.
  778. width (`int`):
  779. Width of the input image.
  780. images_kwargs (`dict`, *optional*)
  781. Any kwargs to override defaults of the image processor.
  782. Returns:
  783. `int`: Number of patches per image.
  784. """
  785. split_image = images_kwargs.get("split_image", self.split_image)
  786. max_image_size = images_kwargs.get("max_image_size", self.max_image_size)
  787. resized_height, resized_width = select_best_resolution((height, width), self.split_resolutions)
  788. num_patches = 1 if not split_image else resized_height // max_image_size * resized_width // max_image_size
  789. return num_patches
class AriaProcessorKwargs(ProcessingKwargs, total=False):
    # Default keyword arguments for `AriaProcessor.__call__`, grouped by the component
    # that consumes them: "text_kwargs" go to the tokenizer call and "images_kwargs" go
    # to the image processor call.
    _defaults = {
        "text_kwargs": {
            "padding": False,
            "return_mm_token_type_ids": False,
        },
        "images_kwargs": {
            "max_image_size": 980,
            "split_image": False,
        },
        "return_tensors": TensorType.PYTORCH,
    }
  802. class AriaProcessor(ProcessorMixin):
  803. """
  804. AriaProcessor is a processor for the Aria model which wraps the Aria image preprocessor and the LLama slow tokenizer.
  805. Args:
  806. image_processor (`AriaImageProcessor`, *optional*):
  807. The AriaImageProcessor to use for image preprocessing.
  808. tokenizer (`PreTrainedTokenizerBase`, *optional*):
  809. An instance of [`PreTrainedTokenizerBase`]. This should correspond with the model's text model. The tokenizer is a required input.
  810. chat_template (`str`, *optional*):
  811. A Jinja template which will be used to convert lists of messages in a chat into a tokenizable string.
  812. size_conversion (`Dict`, *optional*):
  813. A dictionary indicating size conversions for images.
  814. """
  815. attributes = ["image_processor", "tokenizer"]
  816. image_processor_class = "AriaImageProcessor"
  817. tokenizer_class = "AutoTokenizer"
  818. def __init__(
  819. self,
  820. image_processor=None,
  821. tokenizer: Union[AutoTokenizer, str] = None,
  822. chat_template: Optional[str] = None,
  823. size_conversion: Optional[dict[Union[float, int], int]] = None,
  824. ):
  825. if size_conversion is None:
  826. size_conversion = {490: 128, 980: 256}
  827. self.size_conversion = {int(k): v for k, v in size_conversion.items()}
  828. self.image_token = tokenizer.image_token
  829. self.image_token_id = tokenizer.image_token_id
  830. if tokenizer is not None and tokenizer.pad_token is None:
  831. tokenizer.pad_token = tokenizer.unk_token
  832. super().__init__(image_processor, tokenizer, chat_template=chat_template)
  833. def __call__(
  834. self,
  835. text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]],
  836. images: Optional[ImageInput] = None,
  837. audio=None,
  838. videos=None,
  839. **kwargs: Unpack[AriaProcessorKwargs],
  840. ) -> BatchFeature:
  841. """
  842. Main method to prepare for the model one or several sequences(s) and image(s).
  843. Args:
  844. text (`TextInput`, `PreTokenizedInput`, `list[TextInput]`, `list[PreTokenizedInput]`):
  845. The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
  846. (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
  847. `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
  848. images (`ImageInput`):
  849. The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
  850. tensor. Both channels-first and channels-last formats are supported.
  851. Returns:
  852. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  853. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  854. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  855. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  856. `None`).
  857. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  858. - **pixel_mask** -- Pixel mask to be fed to a model. Returned when `images` is not `None`.
  859. """
  860. output_kwargs = self._merge_kwargs(
  861. AriaProcessorKwargs,
  862. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  863. **kwargs,
  864. )
  865. if isinstance(text, str):
  866. text = [text]
  867. elif not isinstance(text, list) and not isinstance(text[0], str):
  868. raise TypeError("Invalid input text. Please provide a string, or a list of strings")
  869. if images is not None:
  870. image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
  871. # expand the image_token according to the num_crops and tokens per image
  872. tokens_per_image = self.size_conversion[image_inputs.pixel_values.shape[2]]
  873. prompt_strings = []
  874. num_crops = image_inputs.pop("num_crops") * tokens_per_image
  875. for sample in text:
  876. sample = sample.replace(self.tokenizer.image_token, self.tokenizer.image_token * num_crops)
  877. prompt_strings.append(sample)
  878. else:
  879. image_inputs = {}
  880. prompt_strings = text
  881. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  882. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  883. text_inputs = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"], return_tensors=None)
  884. self._check_special_mm_tokens(prompt_strings, text_inputs, modalities=["image"])
  885. if return_mm_token_type_ids:
  886. array_ids = np.array(text_inputs["input_ids"])
  887. mm_token_type_ids = np.zeros_like(text_inputs["input_ids"])
  888. mm_token_type_ids[array_ids == self.image_token_id] = 1
  889. text_inputs["mm_token_type_ids"] = mm_token_type_ids.tolist()
  890. return BatchFeature(data={**text_inputs, **image_inputs}, tensor_type=return_tensors)
  891. def _get_num_multimodal_tokens(self, image_sizes=None, **kwargs):
  892. """
  893. Computes the number of placeholder tokens needed for multimodal inputs with the given sizes.
  894. Args:
  895. image_sizes (`list[list[int]]`, *optional*):
  896. The input sizes formatted as (height, width) per each image.
  897. Returns:
  898. `MultiModalData`: A `MultiModalData` object holding number of tokens per each of the provided
  899. input modalities, along with other useful data.
  900. """
  901. vision_data = {}
  902. if image_sizes is not None:
  903. images_kwargs = AriaProcessorKwargs._defaults.get("images_kwargs", {})
  904. images_kwargs.update(kwargs)
  905. max_size = images_kwargs.get("max_image_size", None) or self.image_processor.max_image_size
  906. num_image_patches = [
  907. self.image_processor.get_number_of_image_patches(*image_size, images_kwargs)
  908. for image_size in image_sizes
  909. ]
  910. num_image_tokens = [self.size_conversion[max_size] * num_patches for num_patches in num_image_patches]
  911. vision_data.update({"num_image_tokens": num_image_tokens, "num_image_patches": num_image_patches})
  912. return MultiModalData(**vision_data)
  913. @property
  914. def model_input_names(self):
  915. tokenizer_input_names = self.tokenizer.model_input_names
  916. image_processor_input_names = self.image_processor.model_input_names
  917. # Remove `num_crops`, it is popped and used only when processing. Make a copy of list when removing
  918. # otherwise `self.image_processor.model_input_names` is also modified
  919. image_processor_input_names = [name for name in image_processor_input_names if name != "num_crops"]
  920. return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
class AriaSharedExpertsMLP(LlamaMLP):
    """
    Shared Expert MLP for shared experts.

    Unlike routed experts, shared experts process all tokens without routing.
    This class reconfigures the intermediate size in comparison to the LlamaMLP.

    Args:
        config (`AriaTextConfig`): Configuration object for the Aria language model.
    """

    def __init__(self, config: AriaTextConfig):
        super().__init__(config)
        # Scale the intermediate size by the number of shared experts.
        # NOTE(review): this assignment runs after the parent constructor, so read as plain
        # Python it does not resize anything the parent already built. Presumably the modular
        # converter relies on this form — confirm against the generated modeling file.
        self.intermediate_size = config.intermediate_size * config.moe_num_shared_experts
  932. class AriaGroupedExpertsGemm(nn.Module):
  933. """
  934. Grouped GEMM (General Matrix Multiplication) module for efficient expert computation.
  935. This module utilizes the grouped_gemm library (https://github.com/fanshiqing/grouped_gemm)
  936. for optimized performance. If the grouped_gemm library is not installed, it gracefully
  937. falls back to a sequential GEMM implementation, which may be slower but ensures
  938. functionality.
  939. Args:
  940. in_features (`int`):
  941. Number of input features.
  942. out_features (`int`):
  943. Number of output features.
  944. groups (`int`):
  945. Number of expert groups.
  946. """
  947. def __init__(self, in_features, out_features, groups):
  948. super().__init__()
  949. self.in_features = in_features
  950. self.out_features = out_features
  951. self.groups = groups
  952. self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))
  953. def forward(self, input, tokens_per_expert):
  954. """
  955. Perform grouped matrix multiplication.
  956. Args:
  957. input (`torch.Tensor`):
  958. Input tensor of shape (num_tokens, in_features).
  959. tokens_per_expert (`torch.Tensor`):
  960. Number of tokens assigned to each expert.
  961. Returns:
  962. torch.Tensor: Output tensor of shape (num_tokens, out_features).
  963. """
  964. return sequential_experts_gemm(
  965. input,
  966. self.weight,
  967. tokens_per_expert.cpu(),
  968. )
  969. class AriaGroupedExpertsMLP(nn.Module):
  970. """
  971. Grouped MLP module for Mixture of Experts.
  972. Args:
  973. config (`AriaTextConfig`):
  974. Configuration object for the model.
  975. """
  976. def __init__(self, config: AriaTextConfig) -> None:
  977. super().__init__()
  978. self.config = config
  979. self.fc1 = AriaGroupedExpertsGemm(config.hidden_size, config.intermediate_size * 2, config.moe_num_experts)
  980. self.fc2 = AriaGroupedExpertsGemm(config.intermediate_size, config.hidden_size, config.moe_num_experts)
  981. def forward(self, permuted_tokens, tokens_per_expert):
  982. """
  983. Forward pass of the Grouped MLP.
  984. Args:
  985. permuted_tokens (torch.Tensor): Permuted input tokens.
  986. tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert.
  987. Returns:
  988. torch.Tensor: Output tensor after passing through the MLP.
  989. """
  990. fc1_output = self.fc1(permuted_tokens, tokens_per_expert)
  991. projection, gate = torch.chunk(fc1_output, 2, dim=-1)
  992. fc1_output = nn.functional.silu(projection) * gate
  993. fc2_output = self.fc2(fc1_output, tokens_per_expert)
  994. return fc2_output
  995. # Token permutation adapted from https://github.com/NVIDIA/Megatron-LM/blob/54f1f78529cbc2b9cddad313e7f9d96ac0420a27/megatron/core/transformer/moe/token_dispatcher.py#L291-L587
  996. class AriaTextMoELayer(nn.Module):
  997. """
  998. Aria Text Mixture of Experts (MoE) Layer.
  999. This layer applies a gating mechanism to route input tokens to different experts.
  1000. Args:
  1001. config (`AriaTextConfig`):
  1002. Configuration object for the text component of the model.
  1003. """
  1004. def __init__(self, config: AriaTextConfig):
  1005. super().__init__()
  1006. self.router = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False)
  1007. self.experts = AriaGroupedExpertsMLP(config)
  1008. self.shared_experts = AriaSharedExpertsMLP(config)
  1009. self.config = config
  1010. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  1011. """
  1012. Forward pass of the MoE Layer.
  1013. Args:
  1014. hidden_states (`torch.Tensor`):
  1015. Input tensor of shape (batch_size, sequence_length, hidden_size).
  1016. Returns:
  1017. torch.Tensor: Output tensor after passing through the MoE layer.
  1018. Process:
  1019. 1. Route tokens to experts using the router.
  1020. 2. Permute tokens based on routing decisions.
  1021. 3. Process tokens through experts.
  1022. 4. Unpermute and combine expert outputs.
  1023. 5. Add shared expert output to the final result.
  1024. """
  1025. original_shape = hidden_states.shape
  1026. hidden_states = hidden_states.view(-1, hidden_states.size(-1))
  1027. # Top K Routing
  1028. logits = self.router(hidden_states)
  1029. top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
  1030. scores = nn.functional.softmax(top_logits, dim=-1)
  1031. original_dtype = top_indices.dtype
  1032. tokens_per_expert = torch.histc(
  1033. top_indices.flatten().to(torch.float32),
  1034. bins=self.config.moe_num_experts,
  1035. min=0,
  1036. max=self.config.moe_num_experts - 1,
  1037. ).to(original_dtype)
  1038. indices = top_indices
  1039. # Token permutation
  1040. flatten_indices = indices.view(-1)
  1041. sorted_indices = torch.argsort(flatten_indices)
  1042. permuted_tokens = hidden_states.index_select(0, sorted_indices // self.config.moe_topk)
  1043. # Process through experts
  1044. expert_output = self.experts(permuted_tokens, tokens_per_expert)
  1045. # Token unpermutation
  1046. unpermuted_tokens = torch.zeros(
  1047. (scores.shape[0] * self.config.moe_topk, expert_output.size(1)),
  1048. dtype=expert_output.dtype,
  1049. device=expert_output.device,
  1050. )
  1051. unpermuted_tokens.index_copy_(0, sorted_indices, expert_output)
  1052. unpermuted_tokens = unpermuted_tokens.view(-1, self.config.moe_topk, expert_output.size(1))
  1053. output = (unpermuted_tokens * scores.unsqueeze(-1)).sum(dim=1).view(original_shape)
  1054. # Add shared expert output
  1055. shared_expert_output = self.shared_experts(hidden_states.view(original_shape))
  1056. return output + shared_expert_output
class AriaTextAttention(LlamaAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    # No overrides: everything is inherited from `LlamaAttention` under the Aria name.
    pass
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    Aria Text Decoder Layer.

    This class defines a single decoder layer in the language model, incorporating self-attention and Mixture of Experts (MoE) feed-forward network.

    Args:
        config (`AriaTextConfig`):
            Configuration object for the text component of the model.
        layer_idx (`int`):
            Index of the layer.
    """

    def __init__(self, config: AriaTextConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Use the MoE layer as this decoder layer's `mlp`.
        self.mlp = AriaTextMoELayer(config)
  1073. @auto_docstring
  1074. class AriaTextPreTrainedModel(PreTrainedModel):
  1075. config: AriaTextConfig
  1076. base_model_prefix = "model"
  1077. _no_split_modules = ["AriaTextDecoderLayer", "AriaGroupedExpertsGemm"]
  1078. supports_gradient_checkpointing = True
  1079. _skip_keys_device_placement = "past_key_values"
  1080. _supports_flash_attn = True
  1081. _supports_sdpa = True
  1082. _supports_attention_backend = True
  1083. _can_record_outputs = {
  1084. "hidden_states": AriaTextDecoderLayer,
  1085. "attentions": AriaTextAttention,
  1086. }
  1087. def _init_weights(self, module):
  1088. super()._init_weights(module)
  1089. if isinstance(module, AriaGroupedExpertsGemm):
  1090. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  1091. class AriaPreTrainedModel(LlamaPreTrainedModel):
  1092. config: AriaConfig
  1093. base_model_prefix = ""
  1094. _can_compile_fullgraph = False # MoE models don't work with torch.compile (dynamic slicing)
  1095. _supports_attention_backend = True
  1096. def _init_weights(self, module):
  1097. PreTrainedModel._init_weights(self, module)
  1098. if isinstance(module, AriaProjector):
  1099. nn.init.trunc_normal_(module.query, std=self.config.initializer_range)
  1100. class AriaTextModel(LlamaModel):
  1101. def __init__(self, config: AriaTextConfig):
  1102. super().__init__(config)
  1103. self.layers = nn.ModuleList(
  1104. [AriaTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  1105. )
  1106. self.gradient_checkpointing = False
  1107. self.post_init()
  1108. class AriaTextForCausalLM(AriaTextPreTrainedModel, LlamaForCausalLM):
  1109. _tied_weights_keys = ["lm_head.weight"]
  1110. def __init__(self, config: AriaTextConfig):
  1111. super().__init__(config)
  1112. self.model = AriaTextModel(config)
  1113. self.vocab_size = config.vocab_size
  1114. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1115. # Initialize weights and apply final processing
  1116. self.post_init()
  1117. @auto_docstring
  1118. def forward(self, **super_kwargs):
  1119. super().forward(self, **super_kwargs)
class AriaCausalLMOutputWithPast(LlavaCausalLMOutputWithPast):
    # Aria-named alias of the Llava causal-LM output; no fields are added or changed.
    pass
class AriaModelOutputWithPast(LlavaModelOutputWithPast):
    # Aria-named alias of the Llava model output; no fields are added or changed.
    pass
  1124. class AriaModel(LlavaModel):
def __init__(self, config: AriaConfig):
    super().__init__(config)
    # Install Aria's projector as `multi_modal_projector`.
    # NOTE(review): presumably this replaces a projector created by the parent constructor — confirm.
    self.multi_modal_projector = AriaProjector(config)
  1128. def _create_patch_attention_mask(self, pixel_mask):
  1129. if pixel_mask is None:
  1130. return None
  1131. patches_subgrid = pixel_mask.unfold(
  1132. dimension=1,
  1133. size=self.vision_tower.config.patch_size,
  1134. step=self.vision_tower.config.patch_size,
  1135. )
  1136. patches_subgrid = patches_subgrid.unfold(
  1137. dimension=2,
  1138. size=self.vision_tower.config.patch_size,
  1139. step=self.vision_tower.config.patch_size,
  1140. )
  1141. return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
  def get_image_features(
      self,
      pixel_values: torch.FloatTensor,
      pixel_mask: Optional[torch.FloatTensor] = None,
      vision_feature_layer: int = -1,
  ):
      """
      Run the vision tower on the images and pass one of its hidden states through the multimodal projector.

      Args:
          pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
              The tensors corresponding to the input images.
          pixel_mask (`torch.FloatTensor`, *optional*):
              The tensors corresponding to the input image mask.
          vision_feature_layer (`int`, *optional*, defaults to -1):
              Index into the vision tower's `hidden_states` selecting the features to project. When `None`
              is passed explicitly, `self.config.vision_feature_layer` is used instead.

      Returns:
          image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
      """
      if vision_feature_layer is None:
          vision_feature_layer = self.config.vision_feature_layer

      patch_mask = self._create_patch_attention_mask(pixel_mask)
      vision_outputs = self.vision_tower(
          pixel_values, patch_attention_mask=patch_mask, output_hidden_states=True
      )

      # The projector is handed the logical negation of the flattened patch mask (or no mask at all).
      projector_mask = None if patch_mask is None else torch.logical_not(patch_mask.flatten(1))

      return self.multi_modal_projector(
          vision_outputs.hidden_states[vision_feature_layer], attn_mask=projector_mask
      )
  def forward(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      pixel_values: Optional[torch.FloatTensor] = None,
      pixel_mask: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
      position_ids: Optional[torch.LongTensor] = None,
      past_key_values: Optional[Cache] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      use_cache: Optional[bool] = None,
      cache_position: Optional[torch.LongTensor] = None,
      **kwargs: Unpack[FlashAttentionKwargs],
  ) -> Union[tuple, AriaModelOutputWithPast]:
      # Embeds the text tokens, scatters projected image features into the placeholder positions, and
      # runs the language model on the merged embeddings.
      #
      # `pixel_values` / `pixel_mask` are forwarded to `get_image_features`; every other argument is
      # passed to `self.language_model` unchanged. The result is an `AriaModelOutputWithPast` whose
      # `image_hidden_states` holds the projected image features, or `None` when none were computed.

      # 1. Embed the tokens unless the caller already supplied embeddings
      if inputs_embeds is None:
          inputs_embeds = self.get_input_embeddings()(input_ids)

      # Bound up front so the return statement is valid even when the merge below is skipped
      # (e.g. `pixel_values` given together with a length-1 sequence); previously that combination
      # raised `UnboundLocalError`.
      image_features = None

      # 2. Merge text and images (skipped when the sequence dimension has length 1)
      if pixel_values is not None and inputs_embeds.shape[1] != 1:
          image_features = self.get_image_features(
              pixel_values=pixel_values,
              pixel_mask=pixel_mask,
              vision_feature_layer=self.config.vision_feature_layer,
          )
          image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
          special_image_mask = self.get_placeholder_mask(
              input_ids, inputs_embeds=inputs_embeds, image_features=image_features
          )
          # Overwrite the positions selected by the mask with the image features
          inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

      # 3. Run the language model on the (possibly merged) embeddings
      outputs = self.language_model(
          attention_mask=attention_mask,
          position_ids=position_ids,
          past_key_values=past_key_values,
          inputs_embeds=inputs_embeds,
          use_cache=use_cache,
          cache_position=cache_position,
          **kwargs,
      )

      return AriaModelOutputWithPast(
          last_hidden_state=outputs.last_hidden_state,
          # The cache is only exposed when `use_cache` is truthy; with the default `None` it is dropped.
          past_key_values=outputs.past_key_values if use_cache else None,
          hidden_states=outputs.hidden_states,
          attentions=outputs.attentions,
          image_hidden_states=image_features,
      )
  @auto_docstring(
      custom_intro="""
      Aria model for conditional generation tasks.

      This model combines a vision tower, a multi-modal projector, and a language model
      to perform tasks that involve both image and text inputs.
      """
  )
  class AriaForConditionalGeneration(LlavaForConditionalGeneration):
      def get_image_features(
          self,
          pixel_values: torch.FloatTensor,
          pixel_mask: Optional[torch.FloatTensor] = None,
          vision_feature_layer: int = -1,
      ):
          """Forward the arguments unchanged to `self.model.get_image_features` and return its result."""
          return self.model.get_image_features(
              pixel_values=pixel_values,
              pixel_mask=pixel_mask,
              vision_feature_layer=vision_feature_layer,
          )

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          pixel_values: Optional[torch.FloatTensor] = None,
          pixel_mask: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
          logits_to_keep: Union[int, torch.Tensor] = 0,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Union[tuple, AriaCausalLMOutputWithPast]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or `model.image_token_id` (where `model` is your instance of `AriaForConditionalGeneration`).
              Tokens with indices set to `model.image_token_id` are ignored (masked), the loss is only
              computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> import torch
          >>> from transformers import AutoProcessor, AriaForConditionalGeneration
          >>> from transformers.image_utils import load_image

          >>> # Note that passing the image urls (instead of the actual pil images) to the processor is also possible
          >>> image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
          >>> image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
          >>> image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")

          >>> processor = AutoProcessor.from_pretrained("Rhymes-AI/Aria")
          >>> model = AriaForConditionalGeneration.from_pretrained("Rhymes-AI/Aria", dtype=torch.bfloat16, device_map="auto")

          >>> # Create inputs
          >>> messages = [
          ...     {
          ...         "role": "user",
          ...         "content": [
          ...             {"type": "image"},
          ...             {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
          ...             {"type": "image"},
          ...             {"type": "text", "text": "What can we see in this image?"},
          ...         ]
          ...     },
          ...     {
          ...         "role": "user",
          ...         "content": [
          ...             {"type": "image"},
          ...             {"type": "text", "text": "In which city is that bridge located?"},
          ...         ]
          ...     }
          ... ]

          >>> prompts = [processor.apply_chat_template([message], add_generation_prompt=True) for message in messages]
          >>> images = [[image1, image2], [image3]]
          >>> inputs = processor(text=prompts, images=images, padding=True, return_tensors="pt").to(model.device)

          >>> # Generate
          >>> generated_ids = model.generate(**inputs, max_new_tokens=256)
          >>> generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

          >>> print(generated_texts[0])
          Assistant: There are buildings, trees, lights, and water visible in this image.

          >>> print(generated_texts[1])
          Assistant: The bridge is in San Francisco.
          ```"""
          outputs = self.model(
              input_ids=input_ids,
              pixel_values=pixel_values,
              pixel_mask=pixel_mask,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              cache_position=cache_position,
              **kwargs,
          )

          hidden_states = outputs[0]
          # Only compute necessary logits, and do not upcast them to float if we are not computing the loss.
          # An int `logits_to_keep` keeps that many trailing positions (0 keeps all, since `slice(-0, None)`
          # spans the whole axis); any other value is used directly as the index along the sequence axis.
          slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
          logits = self.lm_head(hidden_states[:, slice_indices, :])

          loss = None
          if labels is not None:
              loss = self.loss_function(
                  logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
              )

          return AriaCausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )

      def prepare_inputs_for_generation(
          self,
          input_ids,
          past_key_values=None,
          inputs_embeds=None,
          pixel_values=None,
          pixel_mask=None,
          attention_mask=None,
          cache_position=None,
          logits_to_keep=None,
          **kwargs,
      ):
          """
          Build the model inputs for one generation step.

          The parent implementation prepares the inputs from everything except the image tensors.
          `pixel_values` and `pixel_mask` are then added only when `cache_position[0] == 0`.

          Note: `cache_position` is indexed unconditionally, so callers must supply it even though the
          signature defaults it to `None`.

          Returns:
              The dict produced by the parent implementation, with `pixel_values` and `pixel_mask` added
              on the step where `cache_position[0] == 0`.
          """
          model_inputs = super().prepare_inputs_for_generation(
              input_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              cache_position=cache_position,
              logits_to_keep=logits_to_keep,
              **kwargs,
          )

          if cache_position[0] == 0:
              # If we're in cached decoding stage, pixel values should be None because input ids do not
              # contain the special image token anymore. Otherwise the pixel values must reach the model.
              model_inputs["pixel_values"] = pixel_values
              model_inputs["pixel_mask"] = pixel_mask

          return model_inputs
  # Public API of this module: the names exported by a star-import of it.
  __all__ = [
      "AriaConfig",
      "AriaTextConfig",
      "AriaImageProcessor",
      "AriaProcessor",
      "AriaForConditionalGeneration",
      "AriaPreTrainedModel",
      "AriaTextPreTrainedModel",
      "AriaTextModel",
      "AriaModel",
      "AriaTextForCausalLM",
  ]