modular_emu3.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222
  1. # coding=utf-8
  2. # Copyright 2024 HuggingFace Inc. team. All rights reserved.
  3. #
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import math
  17. from functools import cached_property
  18. from typing import Optional, Union
  19. import torch
  20. import torch.nn as nn
  21. import torch.nn.functional as F
  22. from ...cache_utils import Cache
  23. from ...generation import GenerationMixin
  24. from ...modeling_outputs import CausalLMOutputWithPast
  25. from ...modeling_utils import PreTrainedModel
  26. from ...processing_utils import Unpack
  27. from ...utils import auto_docstring, can_return_tuple, logging
  28. from ...utils.deprecation import deprecate_kwarg
  29. from ..chameleon.modeling_chameleon import (
  30. ChameleonPreTrainedModel,
  31. ChameleonVQVAEEncoderConvDownsample,
  32. )
  33. from ..llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, TransformersKwargs
  34. from ..siglip.modeling_siglip import SiglipAttention
  35. from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig
  36. logger = logging.get_logger(__name__)
  37. class Emu3Attention(LlamaAttention):
  38. pass
  39. # Has extra dropout which no other model in the library has
  40. class Emu3DecoderLayer(LlamaDecoderLayer):
  41. def __init__(self, config: Emu3Config, layer_idx: int):
  42. super().__init__(config, layer_idx)
  43. self.dropout = nn.Dropout(config.attention_dropout)
  44. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  45. def forward(
  46. self,
  47. hidden_states: torch.Tensor,
  48. attention_mask: Optional[torch.Tensor] = None,
  49. position_ids: Optional[torch.LongTensor] = None,
  50. past_key_values: Optional[Cache] = None,
  51. use_cache: Optional[bool] = False,
  52. cache_position: Optional[torch.LongTensor] = None,
  53. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  54. **kwargs: Unpack[TransformersKwargs],
  55. ) -> torch.Tensor:
  56. residual = hidden_states
  57. hidden_states = self.input_layernorm(hidden_states)
  58. hidden_states, _ = self.self_attn(
  59. hidden_states=hidden_states,
  60. attention_mask=attention_mask,
  61. position_ids=position_ids,
  62. past_key_values=past_key_values,
  63. use_cache=use_cache,
  64. cache_position=cache_position,
  65. position_embeddings=position_embeddings,
  66. **kwargs,
  67. )
  68. hidden_states = residual + self.dropout(hidden_states)
  69. residual = hidden_states
  70. hidden_states = self.post_attention_layernorm(hidden_states)
  71. hidden_states = self.mlp(hidden_states)
  72. hidden_states = residual + self.dropout(hidden_states)
  73. return hidden_states
  74. class Emu3VQVAEVectorQuantizer(nn.Module):
  75. """
  76. A module for vector quantization using learned embedding vectors.
  77. This module implements the quantization process similar to te one described in
  78. the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
  79. input vectors into discrete codebook vectors, which are learned during training.
  80. Current implementation improves over previous ones by avoiding costly matrix multiplications
  81. and allowing for post-hoc remapping of indices.
  82. """
  83. def __init__(self, config: Emu3VQVAEConfig):
  84. super().__init__()
  85. self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
  86. self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
  87. def forward(self, hidden_state: torch.Tensor):
  88. batch_size, temporal, channels, height, width = hidden_state.shape
  89. hidden_state = hidden_state.permute(0, 1, 3, 4, 2).contiguous()
  90. hidden_state_flattened = hidden_state.view(-1, channels)
  91. # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
  92. hidden_state_sum = torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
  93. embedding_sum = torch.sum(self.embedding.weight**2, dim=1)
  94. # "bd,dn->bn",
  95. distances = 2 * torch.matmul(hidden_state_flattened, self.embedding.weight.transpose(0, 1))
  96. distances = hidden_state_sum + embedding_sum - distances
  97. min_encoding_indices = torch.argmin(distances, dim=1)
  98. min_encoding_indices = min_encoding_indices.view(batch_size, temporal, height, width)
  99. return min_encoding_indices
  100. class Emu3VQVAEEncoderConvDownsample(ChameleonVQVAEEncoderConvDownsample):
  101. pass
  102. class Emu3VQVAEEncoderConvUpsample(nn.Module):
  103. def __init__(self, in_channels):
  104. super().__init__()
  105. self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
  106. def forward(self, hidden_states):
  107. hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  108. hidden_states = self.conv(hidden_states)
  109. return hidden_states
  110. class Emu3VQVAEConv3d(nn.Module):
  111. def __init__(
  112. self,
  113. in_channel: int,
  114. out_channel: int,
  115. kernel_size: tuple[int],
  116. stride: tuple[int],
  117. ):
  118. super().__init__()
  119. padding_sizes = [one_kernel - one_stride for one_kernel, one_stride in zip(kernel_size[1:], stride[1:])]
  120. self.padding = ()
  121. for pad_size in padding_sizes[::-1]:
  122. self.padding += (pad_size // 2 + pad_size % 2, pad_size // 2)
  123. self.padding += (2, 0)
  124. self.conv = nn.Conv3d(
  125. in_channel,
  126. out_channel,
  127. kernel_size,
  128. stride=stride,
  129. )
  130. def forward(self, hidden_states: torch.Tensor):
  131. hidden_states = F.pad(hidden_states, self.padding)
  132. hidden_states = self.conv(hidden_states)
  133. return hidden_states
  134. class Emu3VQVAESpatialNorm(nn.Module):
  135. def __init__(
  136. self,
  137. in_channels: int,
  138. out_channels: int,
  139. ):
  140. super().__init__()
  141. self.norm_layer = nn.GroupNorm(
  142. num_channels=out_channels,
  143. num_groups=32,
  144. eps=1e-6,
  145. affine=True,
  146. )
  147. self.conv_y = nn.Conv2d(
  148. in_channels,
  149. out_channels,
  150. kernel_size=1,
  151. stride=1,
  152. padding=0,
  153. )
  154. self.conv_b = nn.Conv2d(
  155. in_channels,
  156. out_channels,
  157. kernel_size=1,
  158. stride=1,
  159. padding=0,
  160. )
  161. def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
  162. quant_states = F.interpolate(quant_states, size=hidden_states.shape[-2:], mode="nearest")
  163. hidden_states = self.norm_layer(hidden_states)
  164. hidden_states = hidden_states * self.conv_y(quant_states) + self.conv_b(quant_states)
  165. return hidden_states
  166. class Emu3VQVAETemporalUpsample(nn.Module):
  167. def __init__(
  168. self,
  169. in_channel: int,
  170. out_channel: int,
  171. ):
  172. super().__init__()
  173. self.conv = Emu3VQVAEConv3d(
  174. in_channel,
  175. out_channel,
  176. kernel_size=(3, 3, 3),
  177. stride=(1, 1, 1),
  178. )
  179. def forward(self, hidden_states: torch.Tensor):
  180. batch_size, channels, temporal, height, width = hidden_states.shape
  181. hidden_states = hidden_states.permute(0, 1, 3, 4, 2).contiguous().view(batch_size, -1, temporal)
  182. hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
  183. hidden_states = hidden_states.view(batch_size, channels, height, width, -1).permute(0, 1, 4, 2, 3).contiguous()
  184. hidden_states = self.conv(hidden_states)
  185. return hidden_states
  186. class Emu3VQVAETemporalDownsample(nn.Module):
  187. def __init__(
  188. self,
  189. in_channel: int,
  190. out_channel: int,
  191. ):
  192. super().__init__()
  193. self.conv = Emu3VQVAEConv3d(
  194. in_channel,
  195. out_channel,
  196. kernel_size=(4, 3, 3),
  197. stride=(2, 1, 1),
  198. )
  199. def forward(self, hidden_states: torch.Tensor):
  200. hidden_states = self.conv(hidden_states)
  201. return hidden_states
  202. class Emu3VQVAETemporalResnetBlock(nn.Module):
  203. def __init__(
  204. self,
  205. in_channels,
  206. out_channels=None,
  207. ):
  208. super().__init__()
  209. self.in_channels = in_channels
  210. self.out_channels = in_channels if out_channels is None else out_channels
  211. self.norm1 = nn.BatchNorm3d(in_channels)
  212. self.conv1 = Emu3VQVAEConv3d(
  213. in_channels,
  214. out_channels,
  215. kernel_size=(3, 3, 3),
  216. stride=(1, 1, 1),
  217. )
  218. self.norm2 = nn.BatchNorm3d(out_channels)
  219. self.conv2 = Emu3VQVAEConv3d(
  220. out_channels,
  221. out_channels,
  222. kernel_size=(3, 3, 3),
  223. stride=(1, 1, 1),
  224. )
  225. if self.in_channels != self.out_channels:
  226. self.nin_shortcut = nn.Conv3d(
  227. in_channels,
  228. out_channels,
  229. kernel_size=1,
  230. stride=1,
  231. padding=0,
  232. )
  233. def forward(self, hidden_states):
  234. residual = hidden_states
  235. hidden_states = self.norm1(hidden_states)
  236. hidden_states *= torch.sigmoid(hidden_states)
  237. hidden_states = self.conv1(hidden_states)
  238. hidden_states = self.norm2(hidden_states)
  239. hidden_states *= torch.sigmoid(hidden_states)
  240. hidden_states = self.conv2(hidden_states)
  241. if self.in_channels != self.out_channels:
  242. residual = self.nin_shortcut(residual)
  243. return residual + hidden_states
  244. class Emu3VQVAEResnetBlock(nn.Module):
  245. def __init__(
  246. self,
  247. in_channels: int,
  248. out_channels: Optional[int] = None,
  249. quant_channels: Optional[int] = None,
  250. ):
  251. super().__init__()
  252. self.in_channels = in_channels
  253. out_channels = in_channels if out_channels is None else out_channels
  254. self.out_channels = out_channels
  255. self.quant_channels = quant_channels
  256. if quant_channels is None:
  257. self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
  258. self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
  259. else:
  260. self.norm1 = Emu3VQVAESpatialNorm(quant_channels, in_channels)
  261. self.norm2 = Emu3VQVAESpatialNorm(quant_channels, out_channels)
  262. self.conv1 = nn.Conv2d(
  263. in_channels,
  264. out_channels,
  265. kernel_size=3,
  266. stride=1,
  267. padding=1,
  268. )
  269. self.conv2 = nn.Conv2d(
  270. out_channels,
  271. out_channels,
  272. kernel_size=3,
  273. stride=1,
  274. padding=1,
  275. )
  276. if self.in_channels != self.out_channels:
  277. self.nin_shortcut = nn.Conv2d(
  278. in_channels,
  279. out_channels,
  280. kernel_size=1,
  281. stride=1,
  282. padding=0,
  283. )
  284. def forward(self, hidden_states: torch.Tensor, quant_channels: Optional[torch.Tensor] = None):
  285. norm_args = () if self.quant_channels is None else (quant_channels,)
  286. residual = hidden_states
  287. hidden_states = self.norm1(hidden_states, *norm_args)
  288. hidden_states *= torch.sigmoid(hidden_states)
  289. hidden_states = self.conv1(hidden_states)
  290. hidden_states = self.norm2(hidden_states, *norm_args)
  291. hidden_states *= torch.sigmoid(hidden_states)
  292. hidden_states = self.conv2(hidden_states)
  293. if self.in_channels != self.out_channels:
  294. residual = self.nin_shortcut(residual)
  295. return residual + hidden_states
  296. class Emu3VQVAEAttentionBlock(SiglipAttention):
  297. def __init__(self, config: Emu3VQVAEConfig):
  298. super().__init__(config)
  299. # for compatibility with the attention interface
  300. self.num_key_value_groups = 1
  301. class Emu3VQVAEGroupNorm(nn.GroupNorm):
  302. """
  303. Same as the torch GroupNorm with the only difference that this ones accepts
  304. an optional kwarg `quant_states` which is not used. This class makes it easier to
  305. use SpatialNorm or GroupNorm without conditionals
  306. """
  307. def __init__(self, **kwargs):
  308. super().__init__(**kwargs)
  309. def forward(self, input, quant_states=None):
  310. return F.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
  311. class Emu3VQVAEMiddleBlock(nn.Module):
  312. def __init__(self, config, in_channels, quant_channels=None):
  313. super().__init__()
  314. self.block_1 = Emu3VQVAEResnetBlock(
  315. in_channels=in_channels,
  316. out_channels=in_channels,
  317. quant_channels=quant_channels,
  318. )
  319. self.attn_1 = Emu3VQVAEAttentionBlock(config)
  320. if quant_channels is None:
  321. self.attn_norm = Emu3VQVAEGroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
  322. else:
  323. self.attn_norm = Emu3VQVAESpatialNorm(quant_channels, in_channels)
  324. self.block_2 = Emu3VQVAEResnetBlock(
  325. in_channels=in_channels,
  326. out_channels=in_channels,
  327. quant_channels=quant_channels,
  328. )
  329. def forward(self, hidden_states: torch.FloatTensor, quant_states: Optional[torch.FloatTensor] = None):
  330. hidden_states = self.block_1(hidden_states, quant_states)
  331. residual = hidden_states
  332. hidden_states = self.attn_norm(hidden_states, quant_states)
  333. batch_size, channels, height, width = hidden_states.shape
  334. hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
  335. hidden_states = self.attn_1(hidden_states)[0]
  336. hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
  337. hidden_states = residual + hidden_states
  338. hidden_states = self.block_2(hidden_states, quant_states)
  339. return hidden_states
  340. class Emu3VQVAEDownBlock(nn.Module):
  341. def __init__(self, config):
  342. super().__init__()
  343. self.num_resolutions = len(config.channel_multiplier)
  344. self.num_res_blocks = config.num_res_blocks
  345. base_channels = config.base_channels
  346. channel_multiplier = config.channel_multiplier
  347. in_channel_multiplier = (1,) + tuple(channel_multiplier)
  348. self.in_channel_multiplier = in_channel_multiplier
  349. self.down = nn.ModuleList()
  350. for i_level in range(self.num_resolutions):
  351. block = nn.ModuleList()
  352. attn = nn.ModuleList()
  353. attn_norms = nn.ModuleList()
  354. block_in = base_channels * in_channel_multiplier[i_level]
  355. block_out = base_channels * channel_multiplier[i_level]
  356. for i_block in range(self.num_res_blocks):
  357. block.append(
  358. Emu3VQVAEResnetBlock(
  359. in_channels=block_in,
  360. out_channels=block_out,
  361. )
  362. )
  363. block_in = block_out
  364. if config.attn_resolutions is not None and i_level in config.attn_resolutions:
  365. attn.append(Emu3VQVAEAttentionBlock(config))
  366. attn_norms.append(nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True))
  367. down = nn.Module()
  368. down.block = block
  369. down.attn = attn
  370. down.attn_norms = attn_norms
  371. if i_level != self.num_resolutions - 1:
  372. down.downsample = Emu3VQVAEEncoderConvDownsample(block_in)
  373. self.down.append(down)
  374. def forward(self, hidden_states: torch.FloatTensor):
  375. for i_level, blocks in enumerate(self.down):
  376. for i_block in range(self.num_res_blocks):
  377. hidden_states = blocks.block[i_block](hidden_states)
  378. if len(blocks.attn) > 0:
  379. residual = hidden_states
  380. hidden_states = blocks.attn_norms[i_block](hidden_states)
  381. batch_size, channels, height, width = hidden_states.shape
  382. hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
  383. hidden_states = blocks.attn[i_block](hidden_states)[0]
  384. hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
  385. hidden_states = residual + hidden_states
  386. if i_level != self.num_resolutions - 1:
  387. hidden_states = blocks.downsample(hidden_states)
  388. return hidden_states
  389. class Emu3VQVAEUpBlock(nn.Module):
  390. def __init__(self, config):
  391. super().__init__()
  392. self.num_resolutions = len(config.channel_multiplier)
  393. self.num_res_blocks = config.num_res_blocks
  394. quant_channels = config.embed_dim
  395. block_in = config.base_channels * config.channel_multiplier[-1]
  396. self.up = nn.ModuleList()
  397. for i_level in reversed(range(self.num_resolutions)):
  398. block = nn.ModuleList()
  399. attn = nn.ModuleList()
  400. attn_norms = nn.ModuleList()
  401. block_out = config.base_channels * config.channel_multiplier[i_level]
  402. for i_block in range(self.num_res_blocks + 1):
  403. block.append(
  404. Emu3VQVAEResnetBlock(
  405. in_channels=block_in,
  406. out_channels=block_out,
  407. quant_channels=quant_channels,
  408. )
  409. )
  410. block_in = block_out
  411. if i_level in config.attn_resolutions:
  412. attn.append(Emu3VQVAEAttentionBlock(config))
  413. attn_norms.append(Emu3VQVAESpatialNorm(quant_channels, block_in))
  414. up = nn.Module()
  415. up.block = block
  416. up.attn = attn
  417. up.attn_norms = attn_norms
  418. if i_level != 0:
  419. up.upsample = Emu3VQVAEEncoderConvUpsample(block_in)
  420. self.up.insert(0, up)
  421. def forward(self, hidden_states: torch.FloatTensor, quant_states: torch.FloatTensor):
  422. for i_level, blocks in enumerate(self.up[::-1]):
  423. for i_block in range(self.num_res_blocks + 1):
  424. hidden_states = blocks.block[i_block](hidden_states, quant_states)
  425. if len(blocks.attn) > 0:
  426. residual = hidden_states
  427. hidden_states = blocks.attn_norms[i_block](hidden_states, quant_states)
  428. batch_size, channels, height, width = hidden_states.shape
  429. hidden_states = hidden_states.view(batch_size, channels, height * width).transpose(1, 2)
  430. hidden_states = blocks.attn[i_block](hidden_states)[0]
  431. hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
  432. hidden_states = residual + hidden_states
  433. if i_level != len(self.up) - 1:
  434. hidden_states = blocks.upsample(hidden_states)
  435. return hidden_states
  436. class Emu3VQVAEEncoder(nn.Module):
  437. def __init__(self, config):
  438. super().__init__()
  439. base_channels = config.base_channels
  440. in_channels = config.in_channels
  441. double_latent = config.double_latent
  442. latent_channels = config.latent_channels
  443. channel_multiplier = config.channel_multiplier
  444. out_channels = 2 * latent_channels if double_latent else latent_channels
  445. block_in = base_channels * channel_multiplier[-1]
  446. self.conv_in = torch.nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)
  447. self.down_block = Emu3VQVAEDownBlock(config)
  448. self.middle_block = Emu3VQVAEMiddleBlock(config, block_in)
  449. self.norm_out = torch.nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
  450. self.conv_out = torch.nn.Conv2d(
  451. block_in,
  452. out_channels,
  453. kernel_size=3,
  454. stride=1,
  455. padding=1,
  456. )
  457. temporal_down_blocks = int(math.log2(config.temporal_downsample_factor))
  458. self.time_conv = nn.ModuleList()
  459. self.time_res_stack = nn.ModuleList()
  460. for i in range(temporal_down_blocks):
  461. conv = Emu3VQVAETemporalDownsample(out_channels, out_channels)
  462. self.time_conv.append(conv)
  463. for _ in range(config.num_res_blocks):
  464. time_res_conv = Emu3VQVAETemporalResnetBlock(
  465. in_channels=out_channels,
  466. out_channels=out_channels,
  467. )
  468. self.time_res_stack.append(time_res_conv)
  469. def forward(self, pixel_values: torch.LongTensor):
  470. temporal_dim = pixel_values.shape[1]
  471. pixel_values = pixel_values.reshape(-1, *pixel_values.shape[2:])
  472. # downsampling & middle
  473. hidden_states = self.conv_in(pixel_values)
  474. hidden_states = self.down_block(hidden_states)
  475. hidden_states = self.middle_block(hidden_states)
  476. # end
  477. hidden_states = self.norm_out(hidden_states)
  478. hidden_states *= torch.sigmoid(hidden_states)
  479. hidden_states = self.conv_out(hidden_states)
  480. hidden_states = hidden_states.reshape(-1, temporal_dim, *hidden_states.shape[1:])
  481. hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
  482. # temporal convs
  483. for conv in self.time_conv:
  484. hidden_states = conv(hidden_states)
  485. hidden_states *= torch.sigmoid(hidden_states)
  486. for layer in self.time_res_stack:
  487. hidden_states = layer(hidden_states)
  488. hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
  489. return hidden_states
  490. class Emu3VQVAEDecoder(nn.Module):
  491. def __init__(self, config: Emu3VQVAEConfig):
  492. super().__init__()
  493. quant_channels = config.embed_dim
  494. block_in = config.base_channels * config.channel_multiplier[-1]
  495. self.time_res_stack = nn.ModuleList()
  496. for _ in range(config.num_res_blocks):
  497. time_res_conv = Emu3VQVAETemporalResnetBlock(
  498. in_channels=config.latent_channels, out_channels=config.latent_channels
  499. )
  500. self.time_res_stack.append(time_res_conv)
  501. temp_upsample_block_num = int(math.log2(config.temporal_downsample_factor))
  502. self.time_conv = nn.ModuleList()
  503. for i in range(temp_upsample_block_num):
  504. conv = Emu3VQVAETemporalUpsample(config.latent_channels, config.latent_channels)
  505. self.time_conv.append(conv)
  506. self.conv_in = nn.Conv2d(
  507. config.latent_channels,
  508. block_in,
  509. kernel_size=3,
  510. stride=1,
  511. padding=1,
  512. )
  513. self.middle_block = Emu3VQVAEMiddleBlock(config, block_in, quant_channels=quant_channels)
  514. self.up_block = Emu3VQVAEUpBlock(config)
  515. block_in = config.base_channels * config.channel_multiplier[0]
  516. self.norm_out = Emu3VQVAESpatialNorm(quant_channels, block_in)
  517. self.conv_out = nn.Conv2d(
  518. block_in,
  519. config.out_channels,
  520. kernel_size=3,
  521. stride=1,
  522. padding=1,
  523. )
  524. def forward(self, hidden_states: torch.Tensor, quant_states: torch.Tensor):
  525. hidden_quant_states = torch.cat((hidden_states, quant_states), dim=0)
  526. hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)
  527. # temporal convs
  528. for layer in self.time_res_stack:
  529. hidden_quant_states = layer(hidden_quant_states)
  530. for layer in self.time_conv:
  531. hidden_quant_states = layer(hidden_quant_states)
  532. hidden_quant_states *= torch.sigmoid(hidden_quant_states)
  533. hidden_quant_states = hidden_quant_states.permute(0, 2, 1, 3, 4)
  534. hidden_states, quant_states = torch.chunk(hidden_quant_states, 2, dim=0)
  535. hidden_states = hidden_states.reshape(-1, *hidden_states.shape[2:])
  536. quant_states = quant_states.reshape(-1, *quant_states.shape[2:])
  537. hidden_states = self.conv_in(hidden_states)
  538. # middle & upsampling
  539. hidden_states = self.middle_block(hidden_states, quant_states)
  540. hidden_states = self.up_block(hidden_states, quant_states)
  541. hidden_states = self.norm_out(hidden_states, quant_states)
  542. hidden_states *= torch.sigmoid(hidden_states)
  543. hidden_states = self.conv_out(hidden_states)
  544. return hidden_states
  545. @auto_docstring(
  546. custom_intro="""
  547. The VQ-VAE model used in Emu3 for encoding/decoding images into discrete tokens.
  548. This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
  549. [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
  550. Taigman](https://huggingface.co/papers/2203.13131).
  551. """
  552. )
  553. class Emu3VQVAE(PreTrainedModel):
  554. config: Emu3VQVAEConfig
  555. base_model_prefix = "emuvideovq"
  556. main_input_name = "pixel_values"
  557. _supports_sdpa = True
  558. _supports_flash_attn = True
  559. _supports_flex_attn = True
  560. _supports_attention_backend = True
  561. _no_split_modules = [
  562. "Emu3VQVAETemporalResnetBlock",
  563. "Emu3VQVAEAttentionBlock",
  564. "Emu3VQVAEResnetBlock",
  565. "Emu3VQVAEVectorQuantizer",
  566. ]
  567. def _init_weights(self, module):
  568. if isinstance(module, (nn.Conv2d, nn.Conv3d)):
  569. nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
  570. if module.bias is not None:
  571. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
  572. bound = 1 / math.sqrt(fan_in)
  573. nn.init.uniform_(module.bias, -bound, bound)
  574. elif isinstance(module, nn.Linear):
  575. nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
  576. if module.bias is not None:
  577. fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
  578. bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
  579. nn.init.uniform_(module.bias, -bound, bound)
  580. elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
  581. nn.init.constant_(module.weight, 1.0)
  582. nn.init.constant_(module.bias, 0.0)
  583. elif isinstance(module, nn.Embedding):
  584. module.weight.data.normal_()
  585. if module.padding_idx is not None:
  586. module.weight.data[module.padding_idx].zero_()
  587. def __init__(self, config: Emu3VQVAEConfig):
  588. super().__init__(config)
  589. self.config = config
  590. self.encoder = Emu3VQVAEEncoder(config)
  591. self.decoder = Emu3VQVAEDecoder(config)
  592. self.quantize = Emu3VQVAEVectorQuantizer(config)
  593. self.vision_spatial_factor = 2 ** (len(config.channel_multiplier) - 1)
  594. self.quant_conv = Emu3VQVAEConv3d(
  595. config.latent_channels, config.embed_dim, kernel_size=(3, 1, 1), stride=(1, 1, 1)
  596. )
  597. self.post_quant_conv = Emu3VQVAEConv3d(
  598. config.embed_dim, config.latent_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1)
  599. )
  600. self.spatial_scale_factor = 2 ** (len(config.channel_multiplier) - 1)
  601. self.eval() # Emu3's VQ model is frozen
  602. self.post_init()
  603. def encode(self, pixel_values: torch.Tensor, image_sizes: torch.Tensor):
  604. is_image = pixel_values.ndim == 4
  605. if is_image:
  606. temporal = self.config.temporal_downsample_factor
  607. batch_size, channels, height, width = pixel_values.shape
  608. pixel_values = pixel_values.unsqueeze(1).repeat(1, temporal, 1, 1, 1)
  609. else:
  610. batch_size, temporal, channels, height, width = pixel_values.shape
  611. hidden_states = self.encoder(pixel_values)
  612. # b t c h w -> b c t h w
  613. hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
  614. hidden_states = self.quant_conv(hidden_states)
  615. # b c t h w -> b t c h w
  616. hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
  617. codes = self.quantize(hidden_states)
  618. image_tokens = codes.squeeze(1) if is_image else codes
  619. image_tokens = [
  620. single_image[: int(size[0] / self.vision_spatial_factor), : int(size[1] / self.vision_spatial_factor)]
  621. for single_image, size in zip(image_tokens, image_sizes)
  622. ]
  623. return image_tokens
  624. def decode(self, hidden_states: torch.Tensor):
  625. is_image = hidden_states.ndim == 3
  626. if is_image:
  627. hidden_states = hidden_states.unsqueeze(1)
  628. batch_size, temporal, height, width = hidden_states.shape
  629. quant = self.quantize.embedding(hidden_states.flatten())
  630. channels = quant.shape[-1]
  631. quant = quant.view(batch_size, temporal, height, width, channels).permute(0, 4, 1, 2, 3).contiguous()
  632. post_quant = self.post_quant_conv(quant)
  633. quant = quant.permute(0, 2, 1, 3, 4)
  634. post_quant = post_quant.permute(0, 2, 1, 3, 4)
  635. video = self.decoder(post_quant, quant)
  636. video = video.reshape(
  637. batch_size,
  638. temporal * self.config.temporal_downsample_factor,
  639. self.config.out_channels,
  640. height * self.spatial_scale_factor,
  641. width * self.spatial_scale_factor,
  642. )
  643. return video[:, 0] if is_image else video
  class Emu3ImageVocabularyMapping:
      """
      A class for mapping discrete image tokens from VQGAN to BPE tokens.

      Args:
          vocab_map (`dict`):
              Maps token strings to BPE ids. Entries whose name starts with `"<|visual token"` are
              treated as image tokens, `"<|extra_200|>"` is used as the end-of-line token and
              `"<image>"` as the image placeholder token.
      """

      def __init__(self, vocab_map):
          self.vocab_map = vocab_map
          # `.get` leaves these as `None` when the vocabulary lacks the entry.
          self.eol_token_id = vocab_map.get("<|extra_200|>")
          self.image_token_id = vocab_map.get("<image>")

      @staticmethod
      def _build_lookup(index_map):
          """Turn an `int -> int` dict into a dense `torch.int` tensor with `lookup[key] == value`."""
          # Slots that have no entry in `index_map` stay 0.
          lookup = torch.zeros(max(index_map) + 1, dtype=torch.int)
          keys = torch.tensor(list(index_map.keys()), dtype=torch.long)
          values = torch.tensor(list(index_map.values()), dtype=torch.int)
          lookup[keys] = values
          return lookup

      @cached_property
      def image_tokens(self):
          """Sorted BPE ids of all visual tokens."""
          return sorted(val for name, val in self.vocab_map.items() if name.startswith("<|visual token"))

      @cached_property
      def image_tokens_str(self):
          """Sorted token strings of all visual tokens."""
          return sorted(name for name in self.vocab_map if name.startswith("<|visual token"))

      @cached_property
      def img2bpe(self):
          """Dict from image code index to BPE id."""
          # The six characters preceding the final two of the token string are parsed as the code index.
          return {int(token[-8:-2]): self.vocab_map[token] for token in self.image_tokens_str}

      @cached_property
      def bpe2img(self):
          """Dict from BPE id to image code index (inverse of `img2bpe`)."""
          return {v: k for k, v in self.img2bpe.items()}

      @cached_property
      def bpe2img_mapping_tensor(self):
          """Dense lookup tensor indexed by BPE id, yielding the image code index."""
          return self._build_lookup(self.bpe2img)

      @cached_property
      def img2bpe_mapping_tensor(self):
          """Dense lookup tensor indexed by image code index, yielding the BPE id."""
          return self._build_lookup(self.img2bpe)

      def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor:
          """
          Map image code indices to BPE ids and append one end-of-line token along the last axis.

          Args:
              img_batch (`torch.Tensor`):
                  Integer tensor of image code indices; the EOL column has `img_batch.shape[0]` rows.

          Returns:
              `torch.Tensor`: BPE ids on the device of `img_batch`, last axis one longer than the input.
          """
          device = img_batch.device
          # The lookup tensor lives on CPU, so index there and move the result back afterwards.
          img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")]
          eol_column = torch.full((img_batch.shape[0], 1), self.eol_token_id, dtype=torch.int)
          return torch.cat([img_tokens, eol_column], dim=-1).to(device)

      def convert_bpe2img(self, img_batch: torch.Tensor) -> torch.Tensor:
          """
          Map BPE ids back to image code indices, dropping the trailing end-of-line entry.

          Args:
              img_batch (`torch.Tensor`):
                  Integer tensor of BPE ids whose last entry along the final axis is the EOL token.

          Returns:
              `torch.Tensor`: image code indices on the device of `img_batch`.
          """
          device = img_batch.device
          img_batch = img_batch[..., :-1]  # remove the trailing EOL tokens
          img_tokens = self.bpe2img_mapping_tensor[img_batch.to("cpu")]
          return img_tokens.to(device)
  class Emu3PreTrainedModel(ChameleonPreTrainedModel, Emu3VQVAE):
      # Class-level settings only; all behavior comes from the base classes.
      # NOTE(review): these flags are presumably read by the pretrained-model machinery — confirm there.
      _no_split_modules = ["Emu3DecoderLayer"]
      _supports_flex_attn = True
      _supports_attention_backend = True
  class Emu3TextModel(LlamaModel, Emu3PreTrainedModel):
      # Maps each recordable output name to the module class that produces it.
      _can_record_outputs = {"hidden_states": Emu3DecoderLayer, "attentions": Emu3Attention}

      def __init__(self, config: Emu3Config):
          super().__init__(config)
          # Replace the layer stack set up by the parent constructor with Emu3 decoder layers.
          decoder_layers = []
          for layer_idx in range(config.num_hidden_layers):
              decoder_layers.append(Emu3DecoderLayer(config, layer_idx))
          self.layers = nn.ModuleList(decoder_layers)
  class Emu3ForCausalLM(LlamaForCausalLM, Emu3PreTrainedModel, GenerationMixin):
      config: Emu3TextConfig

      def __init__(self, config):
          super().__init__(config)
          # Swap the text backbone created by the parent constructor for the Emu3 variant.
          self.model = Emu3TextModel(config)

      # NOTE(review): the `**super_kwargs` signature with a bare `super().forward()` body looks like the
      # modular-converter placeholder that only overrides the docstring — confirm before changing it.
      def forward(**super_kwargs):
          r"""
          Example:

          ```python
          >>> from transformers import Emu3Processor, Emu3ForCausalLM
          >>> import torch

          >>> model = Emu3ForCausalLM.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
          >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

          >>> inputs = processor(text=["Can you write me a poem about winter."], return_tensors="pt").to(model.device)

          >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
          >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
          ```"""
          super().forward()
  class Emu3Model(Emu3PreTrainedModel):
      # Wraps the text decoder together with the VQ-VAE image tokenizer and the image-code <-> BPE mapping.
      _checkpoint_conversion_mapping = {"text_model.model": "text_model"}

      def __init__(self, config):
          super().__init__(config)
          self.text_model = Emu3TextModel._from_config(config.text_config)
          self.vqmodel = Emu3VQVAE(config.vq_config)
          self.vocabulary_mapping = Emu3ImageVocabularyMapping(config.vocabulary_map)

          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.text_model.get_input_embeddings()

      def set_input_embeddings(self, value):
          self.text_model.set_input_embeddings(value)

      def set_decoder(self, decoder):
          self.text_model = decoder

      def get_decoder(self):
          return self.text_model

      def get_image_tokens(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
          """
          Tokenizes images into discrete tokens with VQGAN module. Converts the obtained image tokens
          into BPE tokens (the vocabulary mapping appends an end-of-line token to every row) and
          concatenates the flattened tokens of all images into a single 1D tensor.

          Args:
              pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                  The tensors corresponding to the input images.
              image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
                  The sizes of the images in the batch, being (height, width) for each image.
          """
          image_tokens_list = self.vqmodel.encode(pixel_values, image_sizes)
          bpe_tokens_list = [self.vocabulary_mapping.convert_img2bpe(tokens).flatten() for tokens in image_tokens_list]
          bpe_tokens = torch.cat(bpe_tokens_list)
          return bpe_tokens

      def get_image_features(self, pixel_values: torch.FloatTensor, image_sizes: torch.LongTensor):
          """
          Tokenizes images into discrete tokens with VQGAN module and embeds
          them with text embeddings layer. The result is split into one chunk per image.

          Args:
              pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
                  The tensors corresponding to the input images.
              image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
                  The sizes of the images in the batch, being (height, width) for each image.
          """
          image_tokens = self.get_image_tokens(pixel_values, image_sizes)
          # Every token row carries one extra end-of-line token, hence the `+ 1` on the width.
          split_sizes = [
              (height // self.vqmodel.vision_spatial_factor) * (width // self.vqmodel.vision_spatial_factor + 1)
              for height, width in image_sizes
          ]
          image_features = self.get_input_embeddings()(image_tokens)
          image_features = torch.split(image_features, split_sizes)
          return image_features

      @torch.no_grad
      def decode_image_tokens(self, image_tokens: torch.LongTensor, height: int, width: int):
          """
          Decodes generated image tokens from language model to continuous pixel values
          with VQGAN module via upsampling.

          Args:
              image_tokens (`torch.LongTensor` of shape `(batch_size, num_of_tokens)`):
                  The tensors corresponding to the input images.
              height (`int`):
                  Height of the generated image before upsampling.
              width (`int`):
                  Width of the generated image before upsampling.
          """
          # Drop the last three tokens of every sequence; each remaining row is `width` tokens plus one EOL.
          sequences = image_tokens[:, :-3].view(-1, height, width + 1)
          image_tokens = self.vocabulary_mapping.convert_bpe2img(sequences)
          image = self.vqmodel.decode(image_tokens)
          return image

      def get_placeholder_mask(
          self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor
      ):
          """
          Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
          equal to the length of multimodal features. If the lengths are different, an error is raised.
          """
          if input_ids is None:
              # Without ids, locate placeholders by comparing against the embedding of the image token.
              special_image_mask = inputs_embeds == self.get_input_embeddings()(
                  torch.tensor(self.vocabulary_mapping.image_token_id, dtype=torch.long, device=inputs_embeds.device)
              )
              special_image_mask = special_image_mask.all(-1)
          else:
              special_image_mask = input_ids == self.vocabulary_mapping.image_token_id

          n_image_tokens = special_image_mask.sum()
          special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
          # Count feature vectors rather than scalars so the number is comparable with the token
          # count in the error message, whatever the leading shape of `image_features` is.
          n_image_features = image_features.numel() // image_features.shape[-1]
          if inputs_embeds[special_image_mask].numel() != image_features.numel():
              raise ValueError(
                  f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
              )
          return special_image_mask

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          pixel_values: Optional[torch.FloatTensor] = None,
          image_sizes: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Union[tuple, CausalLMOutputWithPast]:
          r"""
          image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
              The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
              [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
              [`Emu3ImageProcessor`] for processing images).
          """
          if (input_ids is None) ^ (inputs_embeds is not None):
              raise ValueError(
                  "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
              )

          if inputs_embeds is None:
              inputs_embeds = self.get_input_embeddings()(input_ids)

          if pixel_values is not None:
              # Overwrite the embeddings at the image placeholder positions with the image features.
              image_embeds = self.get_image_features(pixel_values, image_sizes)
              image_embeds = torch.cat(image_embeds, dim=0)
              special_image_mask = self.get_placeholder_mask(
                  input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
              )
              inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_embeds)

          # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
          outputs = self.text_model(
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              cache_position=cache_position,
              **kwargs,
          )

          return outputs
  class Emu3ForConditionalGeneration(Emu3PreTrainedModel, GenerationMixin):
      # `Emu3Model` plus a language-modeling head on top of the text hidden states.
      base_model_prefix = ""
      _tied_weights_keys = ["lm_head.weight"]
      _checkpoint_conversion_mapping = {
          "^text_model.model": "model.text_model",
          "^vqmodel": "model.vqmodel",
          "^text_model.lm_head": "lm_head",
      }

      def __init__(self, config):
          super().__init__(config)
          self.model = Emu3Model(config)
          self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)

          self.post_init()

      def get_input_embeddings(self):
          return self.model.get_input_embeddings()

      def set_input_embeddings(self, value):
          self.model.set_input_embeddings(value)

      def get_output_embeddings(self) -> nn.Module:
          return self.lm_head

      def set_decoder(self, decoder):
          self.model.set_decoder(decoder)

      def get_decoder(self):
          return self.model.get_decoder()

      # Make modules available through conditional class for BC
      @property
      def text_model(self):
          return self.model.text_model

      @property
      def vqmodel(self):
          return self.model.vqmodel

      @property
      def vocabulary_mapping(self):
          return self.model.vocabulary_mapping

      def decode_image_tokens(self, **kwargs):
          return self.model.decode_image_tokens(**kwargs)

      @can_return_tuple
      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          pixel_values: Optional[torch.FloatTensor] = None,
          image_sizes: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          logits_to_keep: Union[int, torch.Tensor] = 0,
          **kwargs: Unpack[TransformersKwargs],
      ) -> Union[tuple, CausalLMOutputWithPast]:
          r"""
          image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`):
              The sizes of the images in the batch, being (height, width) for each image. Image sizes can be obtained using
              [`AutoImageProcessor`]. See [`Emu3ImageProcessor.__call__`] for details ([`Emu3Processor`] uses
              [`Emu3ImageProcessor`] for processing images).
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
              config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

          Example:

          ```python
          >>> from transformers import Emu3Processor, Emu3ForConditionalGeneration
          >>> import torch
          >>> import requests
          >>> from PIL import Image

          >>> model = Emu3ForConditionalGeneration.from_pretrained("BAAI/Emu3-Chat-hf", dtype=torch.bfloat16)
          >>> processor = Emu3Processor.from_pretrained("BAAI/Emu3-Chat-hf")

          >>> conversation = [
          ...     {
          ...         "role": "system",
          ...         "content": [
          ...             {"type": "text", "text": "You are a helpful assistant."},
          ...         ],
          ...     },
          ...     {
          ...         "role": "user",
          ...         "content": [
          ...             {"type": "image"},
          ...             {"type": "text", "text": "Please describe the image."},
          ...         ],
          ...     },
          ... ]

          >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
          >>> image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)

          >>> inputs = processor(images=[image], text=[prompt], return_tensors="pt").to(model.device, torch.bfloat16)

          >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
          >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
          ```"""
          # `pixel_values` and `image_sizes` are named parameters, so they are not part of `kwargs`
          # and must be handed to the base model explicitly; otherwise image inputs are dropped.
          outputs = self.model(
              input_ids=input_ids,
              pixel_values=pixel_values,
              image_sizes=image_sizes,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              cache_position=cache_position,
              **kwargs,
          )

          hidden_states = outputs[0]
          # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
          slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
          logits = self.lm_head(hidden_states[:, slice_indices, :])

          loss = None
          if labels is not None:
              loss = self.loss_function(
                  logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
              )

          return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
          )

      def prepare_inputs_for_generation(
          self,
          input_ids,
          past_key_values=None,
          attention_mask=None,
          inputs_embeds=None,
          cache_position=None,
          position_ids=None,
          use_cache=True,
          pixel_values=None,
          **kwargs,
      ):
          # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

          model_inputs = super().prepare_inputs_for_generation(
              input_ids,
              past_key_values=past_key_values,
              attention_mask=attention_mask,
              inputs_embeds=inputs_embeds,
              cache_position=cache_position,
              position_ids=position_ids,
              pixel_values=pixel_values,
              use_cache=use_cache,
              **kwargs,
          )

          # Pixel values are only passed on when the first cache position is 0.
          if cache_position[0] != 0:
              model_inputs["pixel_values"] = None

          return model_inputs
  # Names exported by `from <module> import *`.
  __all__ = [
      "Emu3ForConditionalGeneration", "Emu3ForCausalLM", "Emu3TextModel",
      "Emu3PreTrainedModel", "Emu3VQVAE", "Emu3Model",
  ]