modeling_chameleon.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169
  1. # coding=utf-8
  2. # Copyright 2024 Meta Inc. and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Chameleon model."""
  16. from functools import cached_property
  17. from typing import Callable, Optional, Union
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from ...activations import ACT2FN
  22. from ...cache_utils import Cache, DynamicCache
  23. from ...generation import GenerationMixin
  24. from ...masking_utils import create_causal_mask
  25. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  29. from ...processing_utils import Unpack
  30. from ...utils import (
  31. TransformersKwargs,
  32. auto_docstring,
  33. can_return_tuple,
  34. logging,
  35. )
  36. from ...utils.deprecation import deprecate_kwarg
  37. from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig
  38. logger = logging.get_logger(__name__)
  39. # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Chameleon
  class ChameleonRMSNorm(nn.Module):
      """Root-mean-square normalization over the last dimension with a learnable per-feature scale.

      The statistics are computed in float32; the normalized values are cast back to the dtype of the
      input before being multiplied by `weight`.
      """

      def __init__(self, hidden_size, eps=1e-6):
          """
          ChameleonRMSNorm is equivalent to T5LayerNorm

          Args:
              hidden_size: shape of the learnable `weight` (initialised to ones).
              eps: constant added to the mean of squares before the reciprocal square root.
          """
          super().__init__()
          self.variance_epsilon = eps
          self.weight = nn.Parameter(torch.ones(hidden_size))

      def forward(self, hidden_states):
          """Return `weight * hidden_states / sqrt(mean(hidden_states**2, dim=-1) + eps)`."""
          original_dtype = hidden_states.dtype
          # Upcast so the mean of squares is accumulated in float32.
          upcast = hidden_states.to(torch.float32)
          mean_of_squares = upcast.pow(2).mean(-1, keepdim=True)
          normalized = upcast * torch.rsqrt(mean_of_squares + self.variance_epsilon)
          return self.weight * normalized.to(original_dtype)

      def extra_repr(self):
          """Return the weight shape and epsilon for the module's printed representation."""
          return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  56. # copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Chameleon
  57. # TODO(joao): add me back asap :)
  class ChameleonRotaryEmbedding(nn.Module):
      """Produces the cosine and sine tables of rotary position embeddings for the given position ids."""

      inv_freq: torch.Tensor  # fix linting for `register_buffer`

      def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
          """
          Args:
              dim: size of the rotary dimension; one inverse frequency is created per pair of channels.
              max_position_embeddings: stored on the module; not used by `forward` of this class.
              base: base of the geometric progression of frequencies.
              device: device on which `inv_freq` is created.
              scaling_factor: stored on the module; not used by `forward` of this class.
          """
          super().__init__()
          self.dim = dim
          self.base = base
          self.scaling_factor = scaling_factor
          self.max_position_embeddings = max_position_embeddings
          # For BC: mirrors `max_position_embeddings`; no cos/sin cache is registered in this class.
          self.max_seq_len_cached = max_position_embeddings

          # inv_freq[i] = 1 / base ** (2i / dim)
          exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / self.dim
          self.register_buffer("inv_freq", 1.0 / (self.base**exponents), persistent=False)

      @torch.no_grad()
      def forward(self, x, position_ids):
          """Return `(cos, sin)`, each of shape `(position_ids.shape[0], position_ids.shape[1], dim)`, in `x.dtype`.

          `x` is only read for its device type and dtype.
          """
          # x: [bs, num_attention_heads, seq_len, head_size]
          batch_size = position_ids.shape[0]
          freq_column = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1)
          position_row = position_ids[:, None, :].float()
          # Force float32 since bfloat16 loses precision on long contexts
          # See https://github.com/huggingface/transformers/pull/29285
          autocast_device = "cpu" if x.device.type == "mps" else x.device.type
          with torch.autocast(device_type=autocast_device, enabled=False):
              # Outer product of frequencies and positions, then moved to (batch, seq, dim / 2).
              angles = torch.matmul(freq_column.float(), position_row.float()).transpose(1, 2)
              emb = torch.cat((angles, angles), dim=-1)
              cos, sin = emb.cos(), emb.sin()
          return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  class ChameleonLinearScalingRotaryEmbedding(ChameleonRotaryEmbedding):
      """ChameleonRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

      def forward(self, x, position_ids):
          """Return the parent's `(cos, sin)` evaluated on `position_ids / scaling_factor`."""
          # difference to the original RoPE: a scaling factor is applied to the position ids
          scaled_positions = position_ids.float() / self.scaling_factor
          return super().forward(x, scaled_positions)
  class ChameleonDynamicNTKScalingRotaryEmbedding(ChameleonRotaryEmbedding):
      """ChameleonRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

      def forward(self, x, position_ids):
          """Return the parent's `(cos, sin)`, first enlarging the RoPE base when the positions exceed the original length.

          NOTE: the recomputed `inv_freq` replaces the registered buffer, and nothing in this class restores it
          on later calls with shorter inputs.
          """
          # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
          longest_position = torch.max(position_ids) + 1
          if longest_position > self.max_position_embeddings:
              growth = (self.scaling_factor * longest_position / self.max_position_embeddings) - (
                  self.scaling_factor - 1
              )
              rescaled_base = self.base * growth ** (self.dim / (self.dim - 2))
              exponents = (
                  torch.arange(0, self.dim, 2, dtype=torch.int64).to(device=x.device, dtype=torch.float) / self.dim
              )
              # TODO joao: this may break with compilation
              self.register_buffer("inv_freq", 1.0 / (rescaled_base**exponents), persistent=False)

          return super().forward(x, position_ids)
  111. # Copied from transformers.models.llama.modeling_llama.rotate_half
  def rotate_half(x):
      """Rotates half the hidden dims of the input.

      The last dimension is split at `x.shape[-1] // 2`; the result is the negated second part followed by
      the first part.
      """
      split_at = x.shape[-1] // 2
      leading, trailing = x[..., :split_at], x[..., split_at:]
      return torch.cat((-trailing, leading), dim=-1)
  117. # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
      """Applies Rotary Position Embedding to the query and key tensors.

      Args:
          q (`torch.Tensor`): The query tensor.
          k (`torch.Tensor`): The key tensor.
          cos (`torch.Tensor`): The cosine part of the rotary embedding.
          sin (`torch.Tensor`): The sine part of the rotary embedding.
          position_ids (`torch.Tensor`, *optional*):
              Deprecated and unused.
          unsqueeze_dim (`int`, *optional*, defaults to 1):
              Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
              For example, if `cos` and `sin` have the shape [batch_size, seq_len, head_dim] and `q` and `k` have
              the shape [batch_size, heads, seq_len, head_dim], use unsqueeze_dim=1. If `q` and `k` have the shape
              [batch_size, seq_len, heads, head_dim], use unsqueeze_dim=2.

      Returns:
          `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
      """
      cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

      def _rotate(states):
          # Same formula for queries and keys: x * cos + rotate_half(x) * sin.
          return (states * cos) + (rotate_half(states) * sin)

      return _rotate(q), _rotate(k)
  142. # Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Chameleon
  class ChameleonMLP(nn.Module):
      """Gated feed-forward block computing `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

      def __init__(self, config):
          """Build the three projections from `config.hidden_size`, `config.intermediate_size`,
          `config.mlp_bias` and look up the activation named by `config.hidden_act` in `ACT2FN`."""
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          self.intermediate_size = config.intermediate_size
          use_bias = config.mlp_bias
          self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
          self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
          self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
          self.act_fn = ACT2FN[config.hidden_act]

      # Ignore copy
      def forward(self, x):
          """Apply the gated projection to `x` and map the result back to `hidden_size`."""
          gate = self.act_fn(self.gate_proj(x))
          return self.down_proj(gate * self.up_proj(x))
  class ChameleonLayerNorm(nn.LayerNorm):
      """
      LayerNorm but computes stats only over the last dim because Chameleon applies gamma and beta
      from each shard separately to each head, instead of reducing. We can apply each head's own
      gamma/beta by repeat-interleaving weights from each shard, but the stats have to be computed
      in the last dimension. This module applies gamma/beta manually to fulfill this requirement.
      """

      def __init__(self, hidden_size, *args, **kwargs):
          """
          Args:
              hidden_size: shape of the affine parameters, e.g. `(num_heads, head_dim)`; a plain int is
                  accepted as well. Only its last entry is used for the normalization statistics.
              *args, **kwargs: forwarded unchanged to `nn.LayerNorm` (e.g. `eps`).
          """
          super().__init__(hidden_size, *args, **kwargs)
          # `nn.LayerNorm` has already normalised `hidden_size` into a tuple, so indexing the stored
          # shape (instead of the raw argument) also works when an int was passed.
          self.normalized_shape = (self.normalized_shape[-1],)

      def forward(self, hidden_states):
          """Normalize over the last dimension, then apply `weight` and `bias` by broadcasting."""
          # Use the configured epsilon rather than a hard-coded 1e-5 so a custom `eps` is honoured.
          hidden_states = F.layer_norm(hidden_states, self.normalized_shape, None, None, eps=self.eps)
          # `weight` / `bias` are None when the module was built without affine parameters.
          if self.weight is not None:
              hidden_states = hidden_states * self.weight
          if self.bias is not None:
              hidden_states = hidden_states + self.bias
          return hidden_states
  171. # Copied from transformers.models.llama.modeling_llama.repeat_kv
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
      """
      This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
      num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

      When `n_rep` is 1 the input tensor itself is returned.
      """
      batch, num_key_value_heads, slen, head_dim = hidden_states.shape
      if n_rep == 1:
          return hidden_states
      # Insert a repeat axis right after the head axis, broadcast it, then fold it into the head axis.
      expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
      return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  182. # Copied from transformers.models.llama.modeling_llama.eager_attention_forward
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: Optional[torch.Tensor],
      scaling: float,
      dropout: float = 0.0,
      **kwargs: Unpack[TransformersKwargs],
  ):
      """Scaled dot-product attention written with plain matmul / softmax calls.

      Args:
          module: attention module; `num_key_value_groups` and `training` are read from it.
          query: query states.
          key: key states; repeated `module.num_key_value_groups` times along dim 1.
          value: value states; repeated the same way as `key`.
          attention_mask: optional additive mask; its last dimension is cut to the key length before use.
          scaling: factor applied to the query-key scores.
          dropout: dropout probability on the attention weights (active only when `module.training`).
          **kwargs: accepted and ignored.

      Returns:
          `(attn_output, attn_weights)`; `attn_output` has dims 1 and 2 swapped and is contiguous.
      """
      groups = module.num_key_value_groups
      key_states = repeat_kv(key, groups)
      value_states = repeat_kv(value, groups)

      scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
      if attention_mask is not None:
          scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

      # Softmax in float32, then back to the query dtype.
      attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
      attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

      attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
      return attn_output, attn_weights
  class ChameleonAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

      def __init__(self, config: ChameleonConfig, layer_idx: Optional[int] = None):
          """
          Args:
              config: model configuration; head counts, hidden size, dropout, RoPE and bias settings are read from it.
              layer_idx: index handed to the cache in `forward`; a warning is logged when it is `None`.

          Raises:
              ValueError: if `hidden_size` is not divisible by the number of heads, or the RoPE scaling type is unknown.
          """
          super().__init__()
          self.config = config
          self.layer_idx = layer_idx
          if layer_idx is None:
              logger.warning_once(
                  f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                  "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                  "when creating this class."
              )

          self.attention_dropout = config.attention_dropout
          self.hidden_size = config.hidden_size
          self.num_heads = config.num_attention_heads
          self.head_dim = self.hidden_size // self.num_heads
          self.num_key_value_heads = config.num_key_value_heads
          self.num_key_value_groups = self.num_heads // self.num_key_value_heads
          self.max_position_embeddings = config.max_position_embeddings
          self.rope_theta = config.rope_theta
          self.is_causal = True
          self.model_parallel_size = config.model_parallel_size
          self.scaling = self.head_dim**-0.5

          query_width = self.head_dim * self.num_heads
          if query_width != self.hidden_size:
              raise ValueError(
                  f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                  f" and `num_heads`: {self.num_heads})."
              )
          key_value_width = self.num_key_value_heads * self.head_dim
          use_bias = config.attention_bias

          self.q_proj = nn.Linear(self.hidden_size, query_width, bias=use_bias)
          self.k_proj = nn.Linear(self.hidden_size, key_value_width, bias=use_bias)
          self.v_proj = nn.Linear(self.hidden_size, key_value_width, bias=use_bias)
          self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)
          # Query/key normalization with separate affine parameters per head.
          self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim))
          self.k_norm = ChameleonLayerNorm((self.num_key_value_heads, self.head_dim))
          self._init_rope()

      # copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
      # TODO(joao): add me back asap :)
      def _init_rope(self):
          """Create `self.rotary_emb` according to `config.rope_scaling` (`None`, `"linear"` or `"dynamic"`)."""
          rope_scaling = self.config.rope_scaling
          if rope_scaling is None:
              self.rotary_emb = ChameleonRotaryEmbedding(
                  self.head_dim,
                  max_position_embeddings=self.max_position_embeddings,
                  base=self.rope_theta,
              )
              return

          scaling_type = rope_scaling["type"]
          scaling_factor = rope_scaling["factor"]
          if scaling_type == "linear":
              rope_class = ChameleonLinearScalingRotaryEmbedding
          elif scaling_type == "dynamic":
              rope_class = ChameleonDynamicNTKScalingRotaryEmbedding
          else:
              raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
          self.rotary_emb = rope_class(
              self.head_dim,
              max_position_embeddings=self.max_position_embeddings,
              scaling_factor=scaling_factor,
              base=self.rope_theta,
          )

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          output_attentions: bool = False,
          use_cache: bool = False,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
          """Project, normalize, rotate and attend.

          `output_attentions` and `use_cache` are accepted but not read in this method; extra `kwargs` are
          forwarded to the selected attention implementation.

          Returns:
              `(attn_output, attn_weights)` as produced by the attention implementation, with `attn_output`
              reshaped to `(batch, seq_len, -1)` and passed through `o_proj`.
          """
          bsz, q_len, _ = hidden_states.size()

          query_states = self.q_proj(hidden_states)
          key_states = self.k_proj(hidden_states)
          value_states = self.v_proj(hidden_states)

          # The norms see a (tokens, heads, head_dim) view so the per-head affine parameters broadcast.
          query_states = self.q_norm(query_states.reshape(-1, self.num_heads, self.head_dim))
          key_states = self.k_norm(key_states.reshape(-1, self.num_key_value_heads, self.head_dim))

          # Back to (batch, heads, seq_len, head_dim).
          query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
          key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
          value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

          cos, sin = self.rotary_emb(value_states, position_ids)
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_values is not None:
              # sin and cos are specific to RoPE models; position_ids needed for the static cache
              cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

          implementation = self.config._attn_implementation
          attention_interface: Callable = (
              eager_attention_forward if implementation == "eager" else ALL_ATTENTION_FUNCTIONS[implementation]
          )

          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=self.attention_dropout if self.training else 0.0,
              scaling=self.scaling,
              **kwargs,
          )

          attn_output = self.o_proj(attn_output.reshape(bsz, q_len, -1).contiguous())
          return attn_output, attn_weights
  312. # copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Chameleon, LLAMA->CHAMELEON
  class ChameleonDecoderLayer(GradientCheckpointingLayer):
      """Decoder layer that normalizes the input of each sub-block (attention, MLP) before applying it."""

      def __init__(self, config: ChameleonConfig, layer_idx: int):
          super().__init__()
          self.hidden_size = config.hidden_size

          self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
          self.mlp = ChameleonMLP(config)
          self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          output_attentions: Optional[bool] = False,
          use_cache: Optional[bool] = False,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
          """
          Args:
              hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
              attention_mask (`torch.FloatTensor`, *optional*):
                  attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                  query_sequence_length, key_sequence_length)` if default attention is used.
              position_ids (`torch.LongTensor`, *optional*):
                  Position indices, passed through to the attention module.
              past_key_values (`Cache`, *optional*): cached past key and value projection states
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                  returned tensors for more detail.
              use_cache (`bool`, *optional*):
                  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                  (see `past_key_values`).
              cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                  Indices depicting the position of the input sequence tokens in the sequence
              kwargs (`dict`, *optional*):
                  Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                  into the model

          Returns:
              `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is truthy.
          """
          # Self Attention: norm -> attention -> residual add
          attn_input = self.input_layernorm(hidden_states)
          attn_output, self_attn_weights = self.self_attn(
              hidden_states=attn_input,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              output_attentions=output_attentions,
              use_cache=use_cache,
              cache_position=cache_position,
              **kwargs,
          )
          hidden_states = hidden_states + attn_output

          # Fully Connected: norm -> MLP -> residual add
          mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
          hidden_states = hidden_states + mlp_output

          if output_attentions:
              return (hidden_states, self_attn_weights)
          return (hidden_states,)
  class ChameleonSwinDecoderLayer(GradientCheckpointingLayer):
      """Decoder layer that normalizes the output of each sub-block (attention, MLP) before the residual add."""

      def __init__(self, config: ChameleonConfig, layer_idx: int):
          super().__init__()
          self.hidden_size = config.hidden_size

          self.self_attn = ChameleonAttention(config=config, layer_idx=layer_idx)
          self.mlp = ChameleonMLP(config)
          self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          output_attentions: Optional[bool] = False,
          use_cache: Optional[bool] = False,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
          """
          Args:
              hidden_states (`torch.FloatTensor`):
                  input to the layer of shape `(batch, seq_len, embed_dim)`
              attention_mask (`torch.FloatTensor`, *optional*):
                  attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                  query_sequence_length, key_sequence_length)` if default attention is used.
              position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                  Indices of positions of each input sequence tokens in the position embeddings
              past_key_values (`Cache`, *optional*): cached past key and value projection states
              output_attentions (`bool`, *optional*):
                  Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                  returned tensors for more detail.
              use_cache (`bool`, *optional*):
                  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                  (see `past_key_values`).
              cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                  Indices depicting the position of the input sequence tokens in the sequence.

          Returns:
              `(hidden_states,)`, or `(hidden_states, self_attn_weights)` when `output_attentions` is truthy.
          """
          # Self Attention: attention -> norm -> residual add
          attn_output, self_attn_weights = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              output_attentions=output_attentions,
              use_cache=use_cache,
              cache_position=cache_position,
              **kwargs,
          )
          hidden_states = hidden_states + self.input_layernorm(attn_output)

          # Fully Connected: MLP -> norm -> residual add
          mlp_output = self.post_attention_layernorm(self.mlp(hidden_states))
          hidden_states = hidden_states + mlp_output

          if output_attentions:
              return (hidden_states, self_attn_weights)
          return (hidden_states,)
  class ChameleonVQVAEVectorQuantizer(nn.Module):
      """
      A module for vector quantization using learned embedding vectors.

      This module implements the quantization process similar to the one described in
      the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
      input vectors into discrete codebook vectors, which are learned during training.
      Each input vector is replaced by the codebook row with the smallest squared Euclidean distance.
      """

      def __init__(self, config):
          """Read `num_embeddings`, `embed_dim` and the optional `beta` (default 0.25) from `config`."""
          super().__init__()
          self.num_embeddings = config.num_embeddings
          self.embedding_dim = config.embed_dim
          self.beta = getattr(config, "beta", 0.25)

          self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)

      def forward(self, hidden_state: torch.Tensor):
          """Quantize a 4-D tensor whose dim 1 holds `embedding_dim` channels.

          Returns:
              `(hidden_state_quant, loss, min_encoding_indices)`: the quantized tensor in the input layout,
              the scalar quantization loss, and the flat index of the chosen codebook row for every position.
          """
          # Move channels last so every row of the flattened view is one vector to quantize.
          channels_last = hidden_state.permute(0, 2, 3, 1).contiguous()
          flat_vectors = channels_last.view(-1, self.embedding_dim)
          codebook = self.embedding.weight

          # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
          squared_inputs = torch.sum(flat_vectors**2, dim=1, keepdim=True)
          squared_codes = torch.sum(codebook**2, dim=1)
          cross_term = torch.einsum("bd,dn->bn", flat_vectors, codebook.transpose(0, 1))
          distances = squared_inputs + squared_codes - 2 * cross_term

          min_encoding_indices = torch.argmin(distances, dim=1)
          quantized = self.embedding(min_encoding_indices).view(channels_last.shape)

          # compute loss for embedding
          # First term: gradient reaches only the input; second (scaled by beta): only the codebook.
          input_term = torch.mean((quantized.detach() - channels_last) ** 2)
          codebook_term = torch.mean((quantized - channels_last.detach()) ** 2)
          loss = input_term + self.beta * codebook_term

          # preserve gradients
          quantized = channels_last + (quantized - channels_last).detach()

          # reshape back to match original input shape
          hidden_state_quant = quantized.permute(0, 3, 1, 2).contiguous()
          return hidden_state_quant, loss, min_encoding_indices
  class ChameleonVQVAEEncoderConvDownsample(nn.Module):
      """Stride-2 3x3 convolution that keeps the channel count, applied after padding the right and bottom edges."""

      def __init__(self, in_channels):
          super().__init__()
          self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

      def forward(self, hidden_states):
          """Zero-pad one column on the right and one row at the bottom, then apply the convolution."""
          # no asymmetric padding in torch conv, must do it ourselves
          padded = F.pad(hidden_states, pad=(0, 1, 0, 1), mode="constant", value=0)
          return self.conv(padded)
  class ChameleonVQVAEEncoderResnetBlock(nn.Module):
      """Residual block: two (GroupNorm -> x * sigmoid(x) -> 3x3 conv) stages with dropout before the second conv.

      When the channel count changes, the residual branch is projected with a 3x3 convolution
      (`conv_shortcut=True`) or a 1x1 convolution (default) so both branches can be added.
      """

      def __init__(
          self,
          config,
          in_channels,
          out_channels=None,
          conv_shortcut=False,
      ):
          """
          Args:
              config: object providing `dropout`, the dropout probability.
              in_channels: number of input channels (GroupNorm uses 32 groups, so it must be divisible by 32).
              out_channels: number of output channels; `None` keeps `in_channels`.
              conv_shortcut: use a 3x3 instead of a 1x1 convolution on the residual when channels differ.
          """
          super().__init__()
          self.in_channels = in_channels
          self.out_channels = in_channels if out_channels is None else out_channels
          self.use_conv_shortcut = conv_shortcut

          # Use the resolved `self.out_channels` everywhere so the `out_channels=None` default works.
          resolved_out = self.out_channels

          self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
          self.conv1 = torch.nn.Conv2d(in_channels, resolved_out, kernel_size=3, stride=1, padding=1)
          self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=resolved_out, eps=1e-6, affine=True)
          self.dropout = torch.nn.Dropout(config.dropout)
          self.conv2 = torch.nn.Conv2d(resolved_out, resolved_out, kernel_size=3, stride=1, padding=1)
          if self.in_channels != self.out_channels:
              if self.use_conv_shortcut:
                  self.conv_shortcut = torch.nn.Conv2d(in_channels, resolved_out, kernel_size=3, stride=1, padding=1)
              else:
                  self.nin_shortcut = torch.nn.Conv2d(in_channels, resolved_out, kernel_size=1, stride=1, padding=0)

      def forward(self, hidden_states):
          """Return `shortcut(hidden_states) + block(hidden_states)`."""
          residual = hidden_states

          hidden_states = self.norm1(hidden_states)
          # x * sigmoid(x), written out of place so no intermediate tensor is mutated.
          hidden_states = hidden_states * torch.sigmoid(hidden_states)
          hidden_states = self.conv1(hidden_states)

          hidden_states = self.norm2(hidden_states)
          hidden_states = hidden_states * torch.sigmoid(hidden_states)
          hidden_states = self.dropout(hidden_states)
          hidden_states = self.conv2(hidden_states)

          if self.in_channels != self.out_channels:
              shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
              residual = shortcut(residual)

          return residual + hidden_states
  class ChameleonVQVAEEncoderAttnBlock(nn.Module):
      """Single-head self-attention over the spatial positions of a feature map, added back to the input."""

      def __init__(self, in_channels):
          super().__init__()
          self.in_channels = in_channels

          self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
          # 1x1 convolutions act as per-position linear projections.
          self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
          self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
          self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
          self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

      def forward(self, hidden_states):
          """Return `hidden_states + proj_out(attention(norm(hidden_states)))`; the shape is unchanged."""
          normed = self.norm(hidden_states)
          query_states = self.q(normed)
          key_states = self.k(normed)
          value_states = self.v(normed)

          # compute attention
          batch_size, channels, height, width = query_states.shape
          num_positions = height * width
          query_states = query_states.reshape(batch_size, channels, num_positions).permute(0, 2, 1)
          key_states = key_states.reshape(batch_size, channels, num_positions)
          value_states = value_states.reshape(batch_size, channels, num_positions)

          # (batch, positions, positions): row i holds the weights query position i puts on every key position.
          attn_weights = torch.bmm(query_states, key_states) * (int(channels) ** (-0.5))
          attn_weights = F.softmax(attn_weights, dim=2)

          # attend to values
          attn_output = torch.bmm(value_states, attn_weights.permute(0, 2, 1))
          attn_output = attn_output.reshape(batch_size, channels, height, width)

          return hidden_states + self.proj_out(attn_output)
  546. class ChameleonVQVAEEncoder(nn.Module):
  547. def __init__(self, config):
  548. super().__init__()
  549. self.num_resolutions = len(config.channel_multiplier)
  550. self.num_res_blocks = config.num_res_blocks
  551. base_channels = config.base_channels
  552. resolution = config.resolution
  553. in_channels = config.in_channels
  554. double_latent = config.double_latent
  555. latent_channels = config.latent_channels
  556. channel_multiplier = config.channel_multiplier
  557. self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
  558. curr_res = resolution
  559. in_channel_multiplier = (1,) + tuple(channel_multiplier)
  560. self.in_channel_multiplier = in_channel_multiplier
  561. self.down = nn.ModuleList()
  562. for i_level in range(self.num_resolutions):
  563. block = nn.ModuleList()
  564. attn = nn.ModuleList()
  565. block_in = base_channels * in_channel_multiplier[i_level]
  566. block_out = base_channels * channel_multiplier[i_level]
  567. for i_block in range(self.num_res_blocks):
  568. block.append(
  569. ChameleonVQVAEEncoderResnetBlock(
  570. config=config,
  571. in_channels=block_in,
  572. out_channels=block_out,
  573. )
  574. )
  575. block_in = block_out
  576. if (
  577. config.attn_resolutions is not None
  578. and curr_res in config.attn_resolutions
  579. and config.attn_type == "vanilla"
  580. ):
  581. attn.append(ChameleonVQVAEEncoderAttnBlock(block_in))
  582. down = nn.Module()
  583. down.block = block
  584. down.attn = attn
  585. if i_level != self.num_resolutions - 1:
  586. down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in)
  587. curr_res = curr_res // 2
  588. self.down.append(down)
  589. self.mid = nn.Module()
  590. self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock(
  591. config=config,
  592. in_channels=block_in,
  593. out_channels=block_in,
  594. )
  595. self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock(block_in) if config.attn_type == "vanilla" else nn.Identity()
  596. self.mid.block_2 = ChameleonVQVAEEncoderResnetBlock(
  597. config=config,
  598. in_channels=block_in,
  599. out_channels=block_in,
  600. )
  601. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  602. self.conv_out = torch.nn.Conv2d(
  603. block_in,
  604. 2 * latent_channels if double_latent else latent_channels,
  605. kernel_size=3,
  606. stride=1,
  607. padding=1,
  608. )
  609. def forward(self, pixel_values: torch.LongTensor):
  610. # downsampling
  611. hidden_states = [self.conv_in(pixel_values)]
  612. for i_level in range(self.num_resolutions):
  613. for i_block in range(self.num_res_blocks):
  614. hidden_state = self.down[i_level].block[i_block](
  615. hidden_states[-1],
  616. )
  617. if len(self.down[i_level].attn) > 0:
  618. hidden_state = self.down[i_level].attn[i_block](hidden_state)
  619. hidden_states.append(hidden_state)
  620. if i_level != self.num_resolutions - 1:
  621. hidden_states.append(self.down[i_level].downsample(hidden_states[-1]))
  622. # middle
  623. last_hidden_state = hidden_states[-1]
  624. last_hidden_state = self.mid.block_1(last_hidden_state)
  625. last_hidden_state = self.mid.attn_1(last_hidden_state)
  626. last_hidden_state = self.mid.block_2(last_hidden_state)
  627. # end
  628. last_hidden_state = self.norm_out(last_hidden_state)
  629. last_hidden_state *= torch.sigmoid(last_hidden_state)
  630. last_hidden_state = self.conv_out(last_hidden_state)
  631. return last_hidden_state
  632. class ChameleonImageVocabularyMapping:
  633. """
  634. A class for mapping discrete image tokens from VQGAN to BPE tokens.
  635. """
  636. def __init__(self, vocab_map):
  637. self.vocab_map = vocab_map
  638. self.image_token_id = vocab_map.get("<image>")
  639. @cached_property
  640. def val2name(self):
  641. return {v: k for k, v in self.vocab_map.items()}
  642. @cached_property
  643. def image_tokens(self):
  644. return sorted([val for name, val in self.vocab_map.items() if name.startswith("IMGIMG")])
  645. @cached_property
  646. def bpe2img(self):
  647. img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)}
  648. def remap(old_name: str) -> str:
  649. return "".join(img_tkn_chr_mapping.get(c, c) for c in old_name[len("IMGIMG") : -1])
  650. return {tok: int(remap(self.val2name[tok])) for tok in self.image_tokens}
  651. @cached_property
  652. def img2bpe(self):
  653. return {v: k for k, v in self.bpe2img.items()}
  654. @cached_property
  655. def bpe2img_search_tensors(self):
  656. return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor(sorted(self.bpe2img.values()))
  657. @cached_property
  658. def img2bpe_mapping_tensor(self):
  659. mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int)
  660. for k, v in self.img2bpe.items():
  661. mapping[k] = v
  662. return mapping
  663. def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
  664. device = img_batch.device
  665. img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
  666. return img_tokens.to(device)
@auto_docstring
class ChameleonPreTrainedModel(PreTrainedModel):
    # Common base for the Chameleon model classes in this file. It contains no logic of its
    # own: every line below is a class-level setting read by `PreTrainedModel`.
    # NOTE(review): the exact effect of each flag is defined by `PreTrainedModel`, which is not
    # visible here; the comments below describe intent as suggested by the names -- confirm there.
    config: ChameleonConfig
    # Attribute under which head models keep the base model
    # (`ChameleonForConditionalGeneration` stores its `ChameleonModel` as `self.model`).
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Class names of the two decoder-layer variants `ChameleonModel` can instantiate;
    # presumably these must not be split across devices when sharding -- TODO confirm.
    _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
    # presumably forward kwargs that automatic device placement should leave alone -- TODO confirm
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    # Capability flags: attention backends, full-graph compilation, weight-loading behaviour.
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = True
    _supports_param_buffer_assignment = False
    _supports_flex_attn = True
    _supports_attention_backend = True
@auto_docstring(
    custom_intro="""
    The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
    Taigman](https://huggingface.co/papers/2203.13131).
    """
)
class ChameleonVQVAE(ChameleonPreTrainedModel):
    config: ChameleonVQVAEConfig
    # Replaces the base-class list with the building blocks of the VQ-VAE.
    _no_split_modules = [
        "ChameleonVQVAEVectorQuantizer",
        "ChameleonVQVAEEncoderAttnBlock",
        "ChameleonVQVAEEncoderResnetBlock",
    ]

    def __init__(self, config: ChameleonVQVAEConfig):
        super().__init__(config)

        # `encode` runs these in the order: encoder -> quant_conv -> quantize.
        self.encoder = ChameleonVQVAEEncoder(config)
        self.quantize = ChameleonVQVAEVectorQuantizer(config)
        # 1x1 convolution from `config.latent_channels` to `config.embed_dim` channels.
        # NOTE(review): the encoder emits `2 * latent_channels` channels when
        # `config.double_latent` is set, which would not match this layer -- confirm that
        # configuration is never used with this model.
        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
        # Reverse 1x1 projection; not used by any method defined in this class.
        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
        self.eval()  # Chameleon's VQ model is frozen

    def encode(self, pixel_values: torch.FloatTensor):
        """Run the encoder, the `quant_conv` projection and the vector quantizer on images.

        Args:
            pixel_values: image tensor fed directly to the convolutional encoder.

        Returns:
            The three values produced by `self.quantize`, in order: `quant`, `emb_loss`
            and `indices` (`ChameleonModel.get_image_tokens` uses the third one as the
            discrete image tokens).
        """
        hidden_states = self.encoder(pixel_values)
        hidden_states = self.quant_conv(hidden_states)
        quant, emb_loss, indices = self.quantize(hidden_states)
        return quant, emb_loss, indices
@auto_docstring
class ChameleonModel(ChameleonPreTrainedModel):
    def __init__(self, config: ChameleonConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Translates VQ image-token indices into ids of the text (BPE) vocabulary.
        self.vocabulary_mapping = ChameleonImageVocabularyMapping(config.vocabulary_map)
        # `config.swin_norm` selects which of the two decoder-layer classes is used.
        decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm else ChameleonSwinDecoderLayer
        self.layers = nn.ModuleList(
            [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Image tokenizer, built from the nested `vq_config`.
        self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_image_tokens(self, pixel_values: torch.FloatTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module and converts the
        obtained image tokens into BPE tokens. No begin-of-image / end-of-image
        special tokens are added by this method.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.

        Returns:
            Tensor of BPE token ids reshaped to `(batch_size, -1)`.
        """
        batch_size = pixel_values.shape[0]
        # only the third output of `encode` (the code indices) is needed
        _, _, image_toks = self.vqmodel.encode(pixel_values)
        bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks)
        bpe_toks = bpe_toks.view(batch_size, -1)
        return bpe_toks

    def get_image_features(self, pixel_values: torch.FloatTensor):
        """
        Tokenizes images into discrete tokens with VQGAN module and embeds
        them with the layer returned by `get_input_embeddings()`.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                The tensors corresponding to the input images.

        Returns:
            The embeddings of the image BPE tokens.
        """
        image_tokens = self.get_image_tokens(pixel_values)
        vision_embeddings = self.get_input_embeddings()(image_tokens)
        return vision_embeddings

    def get_placeholder_mask(
        self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
    ):
        """
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder
        token count is equal to the length of multimodal features. If the lengths are different, an error is raised.

        When `input_ids` is `None`, placeholders are located by comparing every embedding vector in
        `inputs_embeds` with the embedding of the image placeholder token.

        Returns:
            Boolean mask expanded to the shape of `inputs_embeds`, true at the image placeholder positions.

        Raises:
            ValueError: if the number of masked embedding elements differs from `image_features.numel()`.
        """
        if input_ids is None:
            # No ids available: a position is a placeholder if its whole embedding vector
            # equals the embedding of the image placeholder token.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

        # counted before the mask is expanded; only used in the error message below
        n_image_tokens = special_image_mask.sum()
        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        # product of the first two dimensions of `image_features`; only used in the error message
        n_image_features = image_features.shape[0] * image_features.shape[1]
        if inputs_embeds[special_image_mask].numel() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )
        return special_image_mask

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        # Explicit arguments win; otherwise fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # The XOR is true when both or neither of `input_ids` / `inputs_embeds` were given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if pixel_values is not None:
            # Overwrite the embeddings at the image placeholder positions with the
            # embeddings of the image tokens computed from `pixel_values`.
            image_embeds = self.get_image_features(pixel_values)
            special_image_mask = self.get_placeholder_mask(
                input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
            )
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # positions of the new tokens, continuing after what is already in the cache
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        # the decoder stack starts from the (possibly image-patched) input embeddings
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            # record the input of every layer; the final normalised output is added after the loop
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            # tuple output: entries that are `None` are dropped, so positions may shift
            return tuple(
                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
@auto_docstring(
    custom_intro="""
    Chameleon Model with a head on top used for outputting logits for next token prediction.
    """
)
class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = ChameleonModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_image_tokens(self, pixel_values):
        """Delegate to `ChameleonModel.get_image_tokens` of the wrapped base model."""
        return self.model.get_image_tokens(pixel_values)

    def get_image_features(self, pixel_values):
        """Delegate to `ChameleonModel.get_image_features` of the wrapped base model."""
        return self.model.get_image_features(pixel_values)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import ChameleonProcessor, ChameleonForConditionalGeneration
        >>> import torch
        >>> import requests
        >>> from PIL import Image

        >>> model = ChameleonForConditionalGeneration.from_pretrained("facebook/chameleon-7b", dtype=torch.bfloat16)
        >>> processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")

        >>> prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
        >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
        >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)

        >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)

        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        # Disallow image tokens: the logits of every "IMGIMG..." vocabulary entry are set to the
        # smallest finite value of the logits dtype (in place, and before the loss is computed).
        # The begin-image and end-image special tokens are not part of this list.
        image_tokens = self.model.vocabulary_mapping.image_tokens
        logits[:, :, image_tokens] = torch.finfo(logits.dtype).min

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        pixel_values=None,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            pixel_values=pixel_values,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            **kwargs,
        )

        # NOTE(review): `cache_position` defaults to `None` in the signature but is indexed
        # unconditionally here, so a caller that omits it gets a `TypeError` -- confirm that
        # every caller always supplies it.
        if cache_position[0] != 0:
            # If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore
            # Otherwise we need pixel values to be passed to model
            model_inputs["pixel_values"] = None

        return model_inputs
# Names exported by a star-import of this module: the public Chameleon model classes.
__all__ = ["ChameleonForConditionalGeneration", "ChameleonModel", "ChameleonPreTrainedModel", "ChameleonVQVAE"]