modeling_pix2struct.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612
  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Inc. & Google team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Pix2Struct modeling file"""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from ...activations import ACT2FN
  21. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  22. from ...generation import GenerationMixin
  23. from ...modeling_attn_mask_utils import AttentionMaskConverter
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. BaseModelOutputWithPooling,
  28. CausalLMOutputWithCrossAttentions,
  29. Seq2SeqLMOutput,
  30. Seq2SeqModelOutput,
  31. )
  32. from ...modeling_utils import PreTrainedModel
  33. from ...utils import (
  34. DUMMY_INPUTS,
  35. DUMMY_MASK,
  36. auto_docstring,
  37. is_torch_flex_attn_available,
  38. is_torch_fx_proxy,
  39. is_torchdynamo_compiling,
  40. logging,
  41. )
  42. from ...utils.deprecation import deprecate_kwarg
  43. from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig
  44. if is_torch_flex_attn_available():
  45. from torch.nn.attention.flex_attention import BlockMask
  46. from ...integrations.flex_attention import make_flex_block_causal_mask
  47. logger = logging.get_logger(__name__)
  48. # General docstring
  49. # Adapted from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Pix2Struct
  50. class Pix2StructLayerNorm(nn.Module):
  51. def __init__(self, hidden_size, eps=1e-6):
  52. """
  53. Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
  54. """
  55. super().__init__()
  56. self.weight = nn.Parameter(torch.ones(hidden_size))
  57. self.variance_epsilon = eps
  58. def forward(self, hidden_states):
  59. # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  60. # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  61. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  62. # half-precision inputs is done in fp32
  63. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  64. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  65. # convert into half-precision if necessary
  66. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  67. hidden_states = hidden_states.to(self.weight.dtype)
  68. return self.weight * hidden_states
  69. try:
  70. from apex.normalization import FusedRMSNorm
  71. Pix2StructLayerNorm = FusedRMSNorm
  72. logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of Pix2StructLayerNorm")
  73. except ImportError:
  74. # using the normal Pix2StructLayerNorm
  75. pass
  76. except Exception:
  77. logger.warning("Discovered apex but it failed to load, falling back to Pix2StructLayerNorm")
  78. pass
  79. class Pix2StructVisionEmbeddings(nn.Module):
  80. r"""
  81. Construct the embeddings from patch. In `Pix2Struct` the input is different from classic Vision-transformer models.
  82. Here the input is a sequence of `seq_len` flattened patches that also combines padding patches (tokens). Each patch
  83. is represented by a vector of `hidden_size` values.
  84. """
  85. def __init__(self, config: Pix2StructConfig) -> None:
  86. super().__init__()
  87. self.patch_projection = nn.Linear(config.patch_embed_hidden_size, config.hidden_size)
  88. self.row_embedder = nn.Embedding(config.seq_len, config.hidden_size)
  89. self.column_embedder = nn.Embedding(config.seq_len, config.hidden_size)
  90. self.dropout = nn.Dropout(config.dropout_rate)
  91. def forward(self, flattened_patches: torch.Tensor) -> torch.Tensor:
  92. # the row and column indices are stored in the first and second position of the flattened_patches
  93. # flattened_patches: `batch_size`, `seq_len`, `hidden_size` + 2
  94. row_indices = flattened_patches[:, :, 0].long()
  95. col_indices = flattened_patches[:, :, 1].long()
  96. flattened_patches = flattened_patches[:, :, 2:]
  97. embeddings = self.patch_projection(flattened_patches)
  98. row_embeddings = self.row_embedder(row_indices)
  99. col_embeddings = self.column_embedder(col_indices)
  100. # sum all embeddings together
  101. embeddings = embeddings + row_embeddings + col_embeddings
  102. embeddings = self.dropout(embeddings)
  103. return embeddings
  104. class Pix2StructVisionAttention(nn.Module):
  105. def __init__(self, config):
  106. super().__init__()
  107. self.hidden_size = config.hidden_size
  108. self.key_value_proj_dim = config.d_kv
  109. self.n_heads = config.num_attention_heads
  110. self.dropout = config.attention_dropout
  111. self.inner_dim = self.n_heads * self.key_value_proj_dim
  112. # Mesh TensorFlow initialization to avoid scaling before softmax
  113. self.query = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
  114. self.key = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
  115. self.value = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
  116. self.output = nn.Linear(self.inner_dim, self.hidden_size, bias=False)
  117. self.gradient_checkpointing = False
  118. def forward(
  119. self,
  120. hidden_states,
  121. attention_mask=None,
  122. position_bias=None,
  123. layer_head_mask=None,
  124. output_attentions=False,
  125. ):
  126. """
  127. Self-attention block
  128. """
  129. # Input is (batch_size, seq_length, dim)
  130. # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
  131. # past_key_values[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
  132. batch_size, seq_length = hidden_states.shape[:2]
  133. def to_projection_shape(states):
  134. """projection"""
  135. return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  136. # get query states
  137. # (batch_size, n_heads, seq_length, dim_per_head)
  138. query_states = to_projection_shape(self.query(hidden_states))
  139. # get key/value states
  140. key_states = to_projection_shape(self.key(hidden_states))
  141. value_states = to_projection_shape(self.value(hidden_states))
  142. # compute scores
  143. # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
  144. scores = torch.matmul(query_states, key_states.transpose(3, 2))
  145. if position_bias is None:
  146. position_bias = torch.zeros(
  147. (1, self.n_heads, seq_length, seq_length), device=scores.device, dtype=scores.dtype
  148. )
  149. if self.gradient_checkpointing and self.training:
  150. position_bias.requires_grad = True
  151. if attention_mask.dim() == 2:
  152. position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
  153. elif attention_mask is not None:
  154. # (batch_size, n_heads, seq_length, key_length)
  155. position_bias = position_bias + attention_mask.to(position_bias.device)
  156. elif not is_torchdynamo_compiling():
  157. attention_mask = torch.ones(
  158. (batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype
  159. )
  160. position_bias = position_bias + attention_mask.to(position_bias.device)
  161. position_bias = 1 - position_bias
  162. position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min)
  163. scores += position_bias_masked
  164. scores = torch.max(scores, torch.tensor(torch.finfo(scores.dtype).min))
  165. # (batch_size, n_heads, seq_length, key_length)
  166. attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).type_as(scores)
  167. # (batch_size, n_heads, seq_length, key_length)
  168. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  169. # Mask heads if we want to
  170. if layer_head_mask is not None:
  171. attn_weights = attn_weights * layer_head_mask
  172. attn_output = torch.matmul(attn_weights, value_states)
  173. # (batch_size, seq_length, dim)
  174. attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
  175. attn_output = self.output(attn_output)
  176. outputs = (attn_output,) + (position_bias,)
  177. if output_attentions:
  178. outputs = outputs + (attn_weights,)
  179. return outputs
  180. # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5DenseGatedActDense->Pix2StructVisionMlp,T5Config->Pix2StructVisionConfig,config.d_model->config.hidden_size,dropout_rate->dropout_rate
  181. class Pix2StructVisionMlp(nn.Module):
  182. def __init__(self, config: Pix2StructVisionConfig):
  183. super().__init__()
  184. self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
  185. self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
  186. self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
  187. self.dropout = nn.Dropout(config.dropout_rate)
  188. self.act = ACT2FN[config.dense_act_fn]
  189. def forward(self, hidden_states):
  190. hidden_gelu = self.act(self.wi_0(hidden_states))
  191. hidden_linear = self.wi_1(hidden_states)
  192. hidden_states = hidden_gelu * hidden_linear
  193. hidden_states = self.dropout(hidden_states)
  194. # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
  195. # See https://github.com/huggingface/transformers/issues/20287
  196. # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
  197. if (
  198. isinstance(self.wo.weight, torch.Tensor)
  199. and hidden_states.dtype != self.wo.weight.dtype
  200. and self.wo.weight.dtype != torch.int8
  201. ):
  202. hidden_states = hidden_states.to(self.wo.weight.dtype)
  203. hidden_states = self.wo(hidden_states)
  204. return hidden_states
  205. class Pix2StructVisionLayer(GradientCheckpointingLayer):
  206. def __init__(self, config: Pix2StructConfig) -> None:
  207. super().__init__()
  208. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  209. self.seq_len_dim = 1
  210. self.attention = Pix2StructVisionAttention(config)
  211. self.mlp = Pix2StructVisionMlp(config)
  212. self.pre_mlp_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  213. self.pre_attention_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  214. def forward(
  215. self,
  216. hidden_states: torch.Tensor,
  217. attention_mask: Optional[torch.Tensor] = None,
  218. head_mask: Optional[torch.Tensor] = None,
  219. output_attentions: bool = False,
  220. ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]:
  221. residual = hidden_states
  222. # in Pix2StructVision, layernorm is applied before self-attention
  223. hidden_states = self.pre_attention_layer_norm(hidden_states)
  224. self_attention_outputs = self.attention(
  225. hidden_states,
  226. attention_mask=attention_mask,
  227. layer_head_mask=head_mask,
  228. output_attentions=output_attentions,
  229. )
  230. attention_output = self_attention_outputs[0]
  231. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  232. # first residual connection
  233. hidden_states = attention_output + residual
  234. # in Pix2StructVision, layernorm is also applied after self-attention
  235. layer_output = self.pre_mlp_layer_norm(hidden_states)
  236. layer_output = self.mlp(layer_output) + hidden_states # second residual connection
  237. outputs = (layer_output,) + outputs
  238. return outputs
  239. class Pix2StructVisionEncoder(nn.Module):
  240. def __init__(self, config: Pix2StructVisionConfig) -> None:
  241. super().__init__()
  242. self.config = config
  243. self.layer = nn.ModuleList([Pix2StructVisionLayer(config) for _ in range(config.num_hidden_layers)])
  244. self.gradient_checkpointing = False
  245. def forward(
  246. self,
  247. hidden_states: torch.Tensor,
  248. attention_mask: Optional[torch.Tensor] = None,
  249. head_mask: Optional[torch.Tensor] = None,
  250. output_attentions: bool = False,
  251. output_hidden_states: bool = False,
  252. return_dict: bool = True,
  253. ) -> Union[tuple, BaseModelOutput]:
  254. all_hidden_states = () if output_hidden_states else None
  255. all_self_attentions = () if output_attentions else None
  256. for i, layer_module in enumerate(self.layer):
  257. if output_hidden_states:
  258. all_hidden_states = all_hidden_states + (hidden_states,)
  259. layer_head_mask = head_mask[i] if head_mask is not None else None
  260. layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
  261. hidden_states = layer_outputs[0]
  262. if output_attentions:
  263. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  264. if output_hidden_states:
  265. all_hidden_states = all_hidden_states + (hidden_states,)
  266. if not return_dict:
  267. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  268. return BaseModelOutput(
  269. last_hidden_state=hidden_states,
  270. hidden_states=all_hidden_states,
  271. attentions=all_self_attentions,
  272. )
  273. @auto_docstring
  274. class Pix2StructPreTrainedModel(PreTrainedModel):
  275. config: Pix2StructConfig
  276. _can_compile_fullgraph = False
  277. @property
  278. def dummy_inputs(self):
  279. input_ids = torch.tensor(DUMMY_INPUTS)
  280. input_mask = torch.tensor(DUMMY_MASK)
  281. dummy_inputs = {
  282. "decoder_input_ids": input_ids,
  283. "input_ids": input_ids,
  284. "decoder_attention_mask": input_mask,
  285. }
  286. return dummy_inputs
  287. def _init_weights(self, module):
  288. """Initialize the weights"""
  289. factor = self.config.initializer_factor # Used for testing weights initialization
  290. if isinstance(module, Pix2StructLayerNorm):
  291. module.weight.data.fill_(factor * 1.0)
  292. elif isinstance(module, Pix2StructTextDenseGatedActDense):
  293. hidden_size = (
  294. self.config.text_config.hidden_size
  295. if isinstance(self.config, Pix2StructConfig)
  296. else self.config.hidden_size
  297. )
  298. d_ff = self.config.text_config.d_ff if isinstance(self.config, Pix2StructConfig) else self.config.d_ff
  299. module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
  300. if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
  301. module.wi_0.bias.data.zero_()
  302. module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
  303. if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
  304. module.wi_1.bias.data.zero_()
  305. module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5))
  306. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  307. module.wo.bias.data.zero_()
  308. elif isinstance(module, Pix2StructTextAttention):
  309. # Mesh TensorFlow attention initialization to avoid scaling before softmax
  310. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
  311. hidden_size = (
  312. self.config.text_config.hidden_size
  313. if isinstance(self.config, Pix2StructConfig)
  314. else self.config.hidden_size
  315. )
  316. key_value_proj_dim = (
  317. self.config.text_config.d_kv if isinstance(self.config, Pix2StructConfig) else self.config.hidden_size
  318. )
  319. n_heads = (
  320. self.config.text_config.num_heads
  321. if isinstance(self.config, Pix2StructConfig)
  322. else self.config.num_heads
  323. )
  324. module.query.weight.data.normal_(mean=0.0, std=factor * ((hidden_size * key_value_proj_dim) ** -0.5))
  325. module.key.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5))
  326. module.value.weight.data.normal_(mean=0.0, std=factor * (hidden_size**-0.5))
  327. module.output.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
  328. if module.has_relative_attention_bias:
  329. module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
  330. elif isinstance(module, nn.Embedding):
  331. hidden_size = (
  332. self.config.text_config.hidden_size
  333. if isinstance(self.config, Pix2StructConfig)
  334. else self.config.hidden_size
  335. )
  336. module.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
  337. if module.padding_idx is not None:
  338. module.weight.data[module.padding_idx].zero_()
  339. elif isinstance(module, Pix2StructTextModel):
  340. hidden_size = (
  341. self.config.text_config.hidden_size
  342. if isinstance(self.config, Pix2StructConfig)
  343. else self.config.hidden_size
  344. )
  345. module.lm_head.weight.data.normal_(mean=0.0, std=factor * ((hidden_size) ** -0.5))
  346. elif isinstance(module, (nn.Linear, nn.Conv2d)):
  347. # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
  348. # `trunc_normal_cpu` not implemented in `half` issues
  349. module.weight.data = nn.init.trunc_normal_(
  350. module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
  351. ).to(module.weight.dtype)
  352. if module.bias is not None:
  353. module.bias.data.zero_()
  354. elif isinstance(module, Pix2StructLayerNorm):
  355. if module.weight is not None:
  356. module.weight.data.fill_(1.0)
  357. elif isinstance(module, nn.Embedding):
  358. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  359. if module.padding_idx is not None:
  360. module.weight.data[module.padding_idx].zero_()
  361. # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->Pix2Struct
  362. def _shift_right(self, input_ids):
  363. decoder_start_token_id = self.config.decoder_start_token_id
  364. pad_token_id = self.config.pad_token_id
  365. if decoder_start_token_id is None:
  366. raise ValueError(
  367. "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id. "
  368. "See Pix2Struct docs for more information."
  369. )
  370. # shift inputs to the right
  371. if is_torch_fx_proxy(input_ids):
  372. # Item assignment is not supported natively for proxies.
  373. shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
  374. shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
  375. else:
  376. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  377. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  378. shifted_input_ids[..., 0] = decoder_start_token_id
  379. if pad_token_id is None:
  380. raise ValueError("self.model.config.pad_token_id has to be defined.")
  381. # replace possible -100 values in labels by `pad_token_id`
  382. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  383. return shifted_input_ids
  384. @auto_docstring
  385. class Pix2StructVisionModel(Pix2StructPreTrainedModel):
  386. config: Pix2StructVisionConfig
  387. main_input_name = "flattened_patches"
  388. supports_gradient_checkpointing = True
  389. _no_split_modules = ["Pix2StructVisionLayer"]
  390. def __init__(self, config: Pix2StructVisionConfig):
  391. super().__init__(config)
  392. self.config = config
  393. self.embeddings = Pix2StructVisionEmbeddings(config)
  394. self.encoder = Pix2StructVisionEncoder(config)
  395. self.layernorm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  396. # Initialize weights and apply final processing
  397. self.post_init()
  398. def get_input_embeddings(self):
  399. return self.embeddings.patch_projection
  400. def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
  401. """
  402. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  403. class PreTrainedModel
  404. """
  405. for layer, heads in heads_to_prune.items():
  406. self.encoder.layer[layer].attention.prune_heads(heads)
  407. @auto_docstring
  408. def forward(
  409. self,
  410. flattened_patches: Optional[torch.Tensor] = None,
  411. attention_mask: Optional[torch.Tensor] = None,
  412. head_mask: Optional[torch.Tensor] = None,
  413. output_attentions: Optional[bool] = None,
  414. output_hidden_states: Optional[bool] = None,
  415. return_dict: Optional[bool] = None,
  416. ) -> Union[tuple, BaseModelOutputWithPooling]:
  417. r"""
  418. flattened_patches (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_channels x patch_height x patch_width)`):
  419. Flattened and padded pixel values. These values can be obtained using [`AutoImageProcessor`]. See
  420. [`Pix2StructVisionImageProcessor.__call__`] for details. Check the [original
  421. paper](https://huggingface.co/papers/2210.03347) (figure 5) for more details.
  422. Example:
  423. ```python
  424. >>> import requests
  425. >>> from PIL import Image
  426. >>> from transformers import AutoProcessor, Pix2StructVisionModel
  427. >>> image_processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
  428. >>> model = Pix2StructVisionModel.from_pretrained("google/pix2struct-textcaps-base")
  429. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  430. >>> image = Image.open(requests.get(url, stream=True).raw)
  431. >>> inputs = image_processor(images=image, return_tensors="pt")
  432. >>> with torch.no_grad():
  433. ... outputs = model(**inputs)
  434. >>> last_hidden_states = outputs.last_hidden_state
  435. >>> list(last_hidden_states.shape)
  436. [1, 2048, 768]
  437. ```
  438. """
  439. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  440. output_hidden_states = (
  441. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  442. )
  443. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  444. if flattened_patches is None:
  445. raise ValueError("You have to specify flattened_patches")
  446. if attention_mask is None:
  447. # check where `flattened_patches` is not 0
  448. attention_mask = (flattened_patches.sum(dim=-1) != 0).float()
  449. # Prepare head mask if needed
  450. # 1.0 in head_mask indicate we keep the head
  451. # attention_probs has shape bsz x n_heads x N x N
  452. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  453. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  454. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  455. embedding_output = self.embeddings(flattened_patches)
  456. encoder_outputs = self.encoder(
  457. embedding_output,
  458. attention_mask=attention_mask,
  459. head_mask=head_mask,
  460. output_attentions=output_attentions,
  461. output_hidden_states=output_hidden_states,
  462. return_dict=return_dict,
  463. )
  464. sequence_output = encoder_outputs[0]
  465. sequence_output = self.layernorm(sequence_output)
  466. if not return_dict:
  467. head_outputs = (sequence_output,)
  468. return head_outputs + encoder_outputs[1:]
  469. return BaseModelOutput(
  470. last_hidden_state=sequence_output,
  471. hidden_states=encoder_outputs.hidden_states,
  472. attentions=encoder_outputs.attentions,
  473. )
  474. # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Pix2StructText,d_model->hidden_size
  475. class Pix2StructTextDenseGatedActDense(nn.Module):
  476. def __init__(self, config: Pix2StructTextConfig):
  477. super().__init__()
  478. self.wi_0 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
  479. self.wi_1 = nn.Linear(config.hidden_size, config.d_ff, bias=False)
  480. self.wo = nn.Linear(config.d_ff, config.hidden_size, bias=False)
  481. self.dropout = nn.Dropout(config.dropout_rate)
  482. self.act = ACT2FN[config.dense_act_fn]
  483. def forward(self, hidden_states):
  484. hidden_gelu = self.act(self.wi_0(hidden_states))
  485. hidden_linear = self.wi_1(hidden_states)
  486. hidden_states = hidden_gelu * hidden_linear
  487. hidden_states = self.dropout(hidden_states)
  488. # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
  489. # See https://github.com/huggingface/transformers/issues/20287
  490. # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
  491. if (
  492. isinstance(self.wo.weight, torch.Tensor)
  493. and hidden_states.dtype != self.wo.weight.dtype
  494. and self.wo.weight.dtype != torch.int8
  495. ):
  496. hidden_states = hidden_states.to(self.wo.weight.dtype)
  497. hidden_states = self.wo(hidden_states)
  498. return hidden_states
  499. class Pix2StructTextLayerFF(nn.Module):
  500. def __init__(self, config: Pix2StructTextConfig):
  501. super().__init__()
  502. self.DenseReluDense = Pix2StructTextDenseGatedActDense(config)
  503. self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
  504. self.dropout = nn.Dropout(config.dropout_rate)
  505. # Copied from transformers.models.t5.modeling_t5.T5LayerFF.forward
  506. def forward(self, hidden_states):
  507. forwarded_states = self.layer_norm(hidden_states)
  508. forwarded_states = self.DenseReluDense(forwarded_states)
  509. hidden_states = hidden_states + self.dropout(forwarded_states)
  510. return hidden_states
  class Pix2StructTextAttention(nn.Module):
      """
      Multi-head attention of the Pix2Struct text decoder, usable as self-attention or as cross-attention.

      Attention scores are not scaled before the softmax. When `has_relative_attention_bias` is set the layer owns
      an embedding table of bucketed relative-position biases that is added to the scores.
      """

      def __init__(
          self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None
      ):
          super().__init__()
          self.has_relative_attention_bias = has_relative_attention_bias
          self.relative_attention_num_buckets = config.relative_attention_num_buckets
          self.relative_attention_max_distance = config.relative_attention_max_distance
          self.hidden_size = config.hidden_size
          self.key_value_proj_dim = config.d_kv
          self.n_heads = config.num_heads
          self.dropout = config.dropout_rate
          self.inner_dim = self.n_heads * self.key_value_proj_dim
          self.layer_idx = layer_idx
          if layer_idx is None:
              logger.warning_once(
                  f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                  "will lead to errors during the forward call, if caching is used. Please make sure to provide a "
                  "`layer_idx` when creating this class."
              )

          # Mesh TensorFlow initialization to avoid scaling before softmax
          # NOTE: the projections map hidden_size -> hidden_size while `forward` reshapes their outputs to
          # (n_heads, key_value_proj_dim), so the layer relies on hidden_size == n_heads * d_kv.
          self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
          self.key = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
          self.value = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
          self.output = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

          if self.has_relative_attention_bias:
              self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
          self.pruned_heads = set()
          self.gradient_checkpointing = False

      @staticmethod
      # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
      def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
          """
          Adapted from Mesh Tensorflow:
          https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

          Translate relative position to a bucket number for relative attention. The relative position is defined as
          memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
          position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
          small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
          positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
          This should allow for more graceful generalization to longer sequences than the model has been trained on

          Args:
              relative_position: an int32 Tensor
              bidirectional: a boolean - whether the attention is bidirectional
              num_buckets: an integer
              max_distance: an integer

          Returns:
              a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
          """
          relative_buckets = 0
          if bidirectional:
              num_buckets //= 2
              relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
              relative_position = torch.abs(relative_position)
          else:
              relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
          # now relative_position is in the range [0, inf)

          # half of the buckets are for exact increments in positions
          max_exact = num_buckets // 2
          is_small = relative_position < max_exact

          # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
          relative_position_if_large = max_exact + (
              torch.log(relative_position.float() / max_exact)
              / math.log(max_distance / max_exact)
              * (num_buckets - max_exact)
          ).to(torch.long)
          relative_position_if_large = torch.min(
              relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
          )

          relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
          return relative_buckets

      # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
      def compute_bias(self, query_length, key_length, device=None, cache_position=None):
          """
          Compute binned relative position bias.

          Args:
              query_length: number of query positions; only used when `cache_position` is None, in which case the
                  query positions are `0 .. query_length - 1`.
              key_length: number of key positions (`0 .. key_length - 1`).
              device: device for the position tensors; defaults to the device of the bias embedding table.
              cache_position: optional 1D tensor with the absolute position of every query token.

          Returns:
              A tensor of shape `(1, num_heads, num_queries, key_length)`.
          """
          if device is None:
              device = self.relative_attention_bias.weight.device
          if cache_position is None:
              context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
          else:
              context_position = cache_position[:, None].to(device)
          memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
          relative_position = memory_position - context_position  # shape (query_length, key_length)
          relative_position_bucket = self._relative_position_bucket(
              relative_position,  # shape (query_length, key_length)
              bidirectional=False,
              num_buckets=self.relative_attention_num_buckets,
              max_distance=self.relative_attention_max_distance,
          )
          values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
          values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
          return values

      # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states,
          mask=None,
          key_value_states=None,
          position_bias=None,
          past_key_values=None,
          layer_head_mask=None,
          query_length=None,
          use_cache=False,
          output_attentions=False,
          cache_position=None,
      ):
          """
          Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

          Args:
              hidden_states: query input of shape `(batch_size, seq_length, dim)`.
              mask: additive mask of shape `(batch_size, 1, 1, key_length)` (non-causal) or
                  `(batch_size, 1, seq_length, key_length)` (causal decoder). Only applied when `position_bias` is
                  computed here; a `position_bias` passed in is used as is.
              key_value_states: source states for cross-attention; None selects self-attention.
              position_bias: precomputed bias to add to the scores; computed here when None.
              past_key_values: optional cache. Cross-attention keys/values are only cached (and re-used) when this
                  is an `EncoderDecoderCache`; with any other cache they are recomputed on every call.
              layer_head_mask: optional multiplicative mask applied to the attention probabilities.
              query_length: total query length including past tokens; takes precedence over `cache_position` when
                  the relative position bias has to be computed.
              use_cache: not read by this method; accepted for interface compatibility.
              output_attentions: whether to append the attention probabilities to the returned tuple.
              cache_position: absolute positions of the current query tokens.

          Returns:
              `(attn_output, position_bias)`, followed by the attention probabilities if `output_attentions`.
          """
          # Input is (batch_size, seq_length, dim)
          batch_size, seq_length = hidden_states.shape[:2]

          # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
          is_cross_attention = key_value_states is not None

          query_states = self.query(hidden_states)
          query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

          # Pick the cache this layer reads from / writes to. `is_updated` must always be bound: it used to be
          # assigned only for `EncoderDecoderCache`, which raised `UnboundLocalError` for cross-attention with any
          # other cache type.
          is_updated = False
          if isinstance(past_key_values, EncoderDecoderCache):
              is_updated = past_key_values.is_updated.get(self.layer_idx)
              if is_cross_attention:
                  # after the first generated id, we can subsequently re-use all key/value_states from cache
                  layer_cache = past_key_values.cross_attention_cache
              else:
                  layer_cache = past_key_values.self_attention_cache
          elif is_cross_attention:
              # A plain cache has a single slot per layer, already used by self-attention; writing the encoder
              # keys/values into it would mix both, so cross-attention is simply recomputed in that case.
              layer_cache = None
          else:
              layer_cache = past_key_values

          current_states = key_value_states if is_cross_attention else hidden_states
          if is_cross_attention and is_updated:
              # reuse k,v, cross_attentions
              key_states = layer_cache.layers[self.layer_idx].keys
              value_states = layer_cache.layers[self.layer_idx].values
          else:
              key_states = self.key(current_states)
              value_states = self.value(current_states)
              key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
              value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

              if layer_cache is not None:
                  # save all key/value_states to cache to be re-used for fast auto-regressive generation
                  cache_position = cache_position if not is_cross_attention else None
                  key_states, value_states = layer_cache.update(
                      key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                  )
                  # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                  # (a cross-attention `layer_cache` can only come from an `EncoderDecoderCache`)
                  if is_cross_attention:
                      past_key_values.is_updated[self.layer_idx] = True

          # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
          scores = torch.matmul(query_states, key_states.transpose(3, 2))

          if position_bias is None:
              key_length = key_states.shape[-2]

              if not self.has_relative_attention_bias:
                  position_bias = torch.zeros(
                      (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                  )
                  if self.gradient_checkpointing and self.training:
                      position_bias.requires_grad = True
              else:
                  # The total query length is only needed here, so it is resolved lazily.
                  if query_length is not None:
                      real_seq_length = query_length
                  elif cache_position is not None:
                      # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
                      real_seq_length = cache_position[-1] + 1
                  elif not is_cross_attention:
                      # In self-attention the keys cover the past plus the current tokens, so the last query sits
                      # at the last key position.
                      real_seq_length = key_length
                  else:
                      raise ValueError(
                          "Either `query_length` or `cache_position` is required to compute the relative position "
                          "bias of a cross-attention layer."
                      )
                  position_bias = self.compute_bias(
                      real_seq_length, key_length, device=scores.device, cache_position=cache_position
                  )
                  position_bias = position_bias[:, :, -seq_length:, :]

              if mask is not None:
                  causal_mask = mask[:, :, :, : key_states.shape[-2]]
                  position_bias = position_bias + causal_mask

          if self.pruned_heads:
              # Drop the bias rows of pruned heads (separate name so the `mask` argument is not shadowed).
              head_keep_mask = torch.ones(position_bias.shape[1])
              head_keep_mask[list(self.pruned_heads)] = 0
              position_bias_masked = position_bias[:, head_keep_mask.bool()]
          else:
              position_bias_masked = position_bias

          scores += position_bias_masked

          # (batch_size, n_heads, seq_length, key_length)
          attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
          attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

          # Mask heads if we want to
          if layer_head_mask is not None:
              attn_weights = attn_weights * layer_head_mask

          attn_output = torch.matmul(attn_weights, value_states)

          attn_output = attn_output.transpose(1, 2).contiguous()
          attn_output = attn_output.view(batch_size, -1, self.inner_dim)
          attn_output = self.output(attn_output)

          outputs = (attn_output, position_bias)

          if output_attentions:
              outputs = outputs + (attn_weights,)
          return outputs
  # Adapted from transformers.models.t5.modeling_t5.T5LayerSelfAttention
  class Pix2StructTextLayerSelfAttention(nn.Module):
      """Self-attention sub-layer: layer norm, attention, dropout, then a residual connection."""

      def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
          super().__init__()
          self.attention = Pix2StructTextAttention(
              config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
          )
          self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states,
          attention_mask=None,
          position_bias=None,
          layer_head_mask=None,
          past_key_values=None,
          use_cache=False,
          output_attentions=False,
          cache_position=None,
      ):
          """
          Run self-attention on the normalized input and add the result back onto `hidden_states`.

          Returns:
              A tuple whose first item is the updated hidden states, followed by every extra item returned by the
              attention module (position bias and, if requested, attention weights).
          """
          attn_result = self.attention(
              self.layer_norm(hidden_states),
              mask=attention_mask,
              position_bias=position_bias,
              layer_head_mask=layer_head_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
              cache_position=cache_position,
          )
          residual_sum = hidden_states + self.dropout(attn_result[0])
          # everything after the attention output is passed through untouched
          return (residual_sum,) + attn_result[1:]
  # Adapted from transformers.models.t5.modeling_t5.T5LayerCrossAttention
  class Pix2StructTextLayerCrossAttention(nn.Module):
      """Cross-attention sub-layer: layer norm, attention over `key_value_states`, dropout, then a residual."""

      def __init__(self, config, layer_idx: Optional[int] = None):
          super().__init__()
          # this attention module is created without a relative attention bias
          self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
          self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states,
          key_value_states,
          attention_mask=None,
          position_bias=None,
          layer_head_mask=None,
          past_key_values=None,
          use_cache=False,
          query_length=None,
          output_attentions=False,
          cache_position=None,
      ):
          """
          Attend from the normalized `hidden_states` to `key_value_states` and add the result back as a residual.

          Returns:
              A tuple whose first item is the updated hidden states, followed by every extra item returned by the
              attention module (position bias and, if requested, attention weights).
          """
          attn_result = self.attention(
              self.layer_norm(hidden_states),
              mask=attention_mask,
              key_value_states=key_value_states,
              position_bias=position_bias,
              layer_head_mask=layer_head_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              query_length=query_length,
              output_attentions=output_attentions,
              cache_position=cache_position,
          )
          residual_sum = hidden_states + self.dropout(attn_result[0])
          # everything after the attention output is passed through untouched
          return (residual_sum,) + attn_result[1:]
  class Pix2StructTextBlock(GradientCheckpointingLayer):
      """One text-decoder block: self-attention, optional cross-attention over encoder states, then feed-forward."""

      def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
          super().__init__()

          self.self_attention = Pix2StructTextLayerSelfAttention(
              config,
              has_relative_attention_bias=has_relative_attention_bias,
              layer_idx=layer_idx,
          )

          self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
              config,
              layer_idx=layer_idx,
          )

          self.mlp = Pix2StructTextLayerFF(config)

      @staticmethod
      def _clamp_fp16_overflow(hidden_states):
          """Clamp a float16 tensor that contains inf values to a finite range; return any other tensor unchanged."""
          # clamp inf values to enable fp16 training
          if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
              clamp_value = torch.finfo(hidden_states.dtype).max - 1000
              hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
          return hidden_states

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states,
          attention_mask=None,
          position_bias=None,
          encoder_hidden_states=None,
          encoder_attention_mask=None,
          encoder_decoder_position_bias=None,
          layer_head_mask=None,
          cross_attn_layer_head_mask=None,
          past_key_values=None,
          use_cache=False,
          output_attentions=False,
          return_dict=True,
          cache_position=None,
      ):
          """
          Run the block on `hidden_states`.

          Cross-attention is only executed when `encoder_hidden_states` is given. `return_dict` is accepted but not
          read; the result is always a tuple.

          Returns:
              `(hidden_states, self-attention position bias, [self-attention weights],
              [cross-attention position bias], [cross-attention weights])`, where the weights are present only when
              `output_attentions` is set and the cross-attention items only when cross-attention ran.
          """
          self_attention_outputs = self.self_attention(
              hidden_states,
              attention_mask=attention_mask,
              position_bias=position_bias,
              layer_head_mask=layer_head_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
              cache_position=cache_position,
          )
          hidden_states = self._clamp_fp16_overflow(self_attention_outputs[0])
          attention_outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights

          do_cross_attention = encoder_hidden_states is not None
          if do_cross_attention:
              # `cache_position` defaults to None, so it must not be indexed unconditionally; without it the query
              # length is left unspecified for the attention layer.
              query_length = cache_position[-1] + 1 if cache_position is not None else None
              cross_attention_outputs = self.encoder_decoder_attention(
                  hidden_states,
                  key_value_states=encoder_hidden_states,
                  attention_mask=encoder_attention_mask,
                  position_bias=encoder_decoder_position_bias,
                  layer_head_mask=cross_attn_layer_head_mask,
                  past_key_values=past_key_values,
                  query_length=query_length,
                  use_cache=use_cache,
                  output_attentions=output_attentions,
              )
              hidden_states = self._clamp_fp16_overflow(cross_attention_outputs[0])

              # Keep cross-attention outputs and relative position weights
              attention_outputs = attention_outputs + cross_attention_outputs[1:]

          # Apply Feed Forward layer
          hidden_states = self._clamp_fp16_overflow(self.mlp(hidden_states))

          outputs = (hidden_states,)

          return outputs + attention_outputs
  @auto_docstring(
      custom_intro="""
      The standalone text decoder of Pix2Struct
      """
  )
  class Pix2StructTextModel(Pix2StructPreTrainedModel):
      config: Pix2StructTextConfig
      _no_split_modules = ["Pix2StructTextBlock"]
      _tied_weights_keys = ["lm_head.weight"]
      supports_gradient_checkpointing = True

      def __init__(self, config):
          super().__init__(config)
          self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

          # only the first block is created with its own relative attention bias
          self.layer = nn.ModuleList(
              [
                  Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
                  for i in range(config.num_layers)
              ]
          )
          self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

          self.lm_head = nn.Linear(config.vocab_size and config.hidden_size, config.vocab_size, bias=False)

          # Set the default before `post_init()` so that this assignment cannot overwrite whatever state the
          # post-init processing establishes for the flag.
          self.gradient_checkpointing = False

          # Initialize weights and apply final processing
          self.post_init()

      def set_input_embeddings(self, new_embeddings):
          self.embed_tokens = new_embeddings

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          cross_attn_head_mask: Optional[torch.Tensor] = None,
          past_key_values: Optional[Cache] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          labels: Optional[torch.LongTensor] = None,
          return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
      ) -> Union[tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              Indices of input sequence tokens in the vocabulary. Pix2StructText is a model with relative position
              embeddings so you should be able to pad the inputs on both the right and the left.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for detail.

              [What are input IDs?](../glossary#input-ids)

              To know more on how to prepare `input_ids` for pretraining take a look at [Pix2StructText
              Training](./t5#training).
          cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
              Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
              `[0, 1]`:

              - 1 indicates the head is **not masked**,
              - 0 indicates the head is **masked**.

          Example:

          ```python
          >>> from transformers import AutoProcessor, Pix2StructTextModel

          >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
          >>> model = Pix2StructTextModel.from_pretrained("google/pix2struct-textcaps-base")

          >>> inputs = processor(text="Hello, my dog is cute", return_tensors="pt")
          >>> outputs = model(**inputs)
          >>> loss = outputs.loss
          ```
          """
          use_cache = use_cache if use_cache is not None else self.config.use_cache
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          if self.gradient_checkpointing and self.training and use_cache:
              logger.warning(
                  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
              )
              use_cache = False

          if input_ids is not None and inputs_embeds is not None:
              raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
          elif input_ids is not None:
              input_shape = input_ids.size()
              input_ids = input_ids.view(-1, input_shape[-1])
          elif inputs_embeds is not None:
              input_shape = inputs_embeds.size()[:-1]
          else:
              raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

          if inputs_embeds is None:
              # explicit check instead of `assert`, which is stripped when Python runs with -O
              if self.embed_tokens is None:
                  raise ValueError("You have to initialize the model with valid token embeddings")
              inputs_embeds = self.embed_tokens(input_ids)

          batch_size, seq_length = input_shape

          if use_cache and past_key_values is None:
              if self.config.is_encoder_decoder:
                  past_key_values = EncoderDecoderCache(
                      DynamicCache(config=self.config), DynamicCache(config=self.config)
                  )
              else:
                  past_key_values = DynamicCache(config=self.config)

          # number of tokens already processed: taken from `cache_position` when given, else from the cache
          past_key_values_length = 0
          if cache_position is not None:
              past_key_values_length = cache_position[0]
          elif past_key_values is not None:
              past_key_values_length = past_key_values.get_seq_length()

          if cache_position is None:
              cache_position = torch.arange(
                  past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
              )

          if attention_mask is None:
              # required mask seq length can be calculated via length of past
              mask_seq_length = (
                  past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
              )
              attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

          if self.config.is_decoder:
              causal_mask = self._update_causal_mask(
                  attention_mask,
                  inputs_embeds,
                  cache_position,
                  past_key_values.self_attention_cache
                  if isinstance(past_key_values, EncoderDecoderCache)
                  else past_key_values,
                  output_attentions,
              )
          else:
              # non-causal case: turn the 2D 1/0 mask into a broadcastable additive mask
              causal_mask = attention_mask[:, None, None, :]
              causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
              causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min

          # If a 2D or 3D attention mask is provided for the cross-attention
          # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
          if encoder_hidden_states is not None:
              encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
              encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
              if encoder_attention_mask is None:
                  encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
              encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
          else:
              encoder_extended_attention_mask = None

          # Prepare head mask if needed
          head_mask = self.get_head_mask(head_mask, self.config.num_layers)
          cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
          all_hidden_states = () if output_hidden_states else None
          all_attentions = () if output_attentions else None
          all_cross_attentions = () if (output_attentions) else None
          position_bias = None
          encoder_decoder_position_bias = None

          hidden_states = self.dropout(inputs_embeds)

          for i, layer_module in enumerate(self.layer):
              layer_head_mask = head_mask[i]
              cross_attn_layer_head_mask = cross_attn_head_mask[i]
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_states,)

              layer_outputs = layer_module(
                  hidden_states,
                  causal_mask,
                  position_bias,
                  encoder_hidden_states,
                  encoder_extended_attention_mask,
                  encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
                  layer_head_mask=layer_head_mask,
                  cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                  past_key_values=past_key_values,
                  use_cache=use_cache,
                  output_attentions=output_attentions,
                  cache_position=cache_position,
              )

              hidden_states = layer_outputs[0]

              # We share the position biases between the layers - the first layer store them
              # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
              # (cross-attention position bias), (cross-attention weights)
              position_bias = layer_outputs[1]
              if encoder_hidden_states is not None:
                  encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]

              if output_attentions:
                  all_attentions = all_attentions + (layer_outputs[2],)
                  if encoder_hidden_states is not None:
                      all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

          hidden_states = self.final_layer_norm(hidden_states)
          hidden_states = self.dropout(hidden_states)

          logits = self.lm_head(hidden_states)

          # Add last layer
          if output_hidden_states:
              all_hidden_states = all_hidden_states + (hidden_states,)

          loss = None
          if labels is not None:
              # move labels to correct device to enable model parallelism
              labels = labels.to(logits.device)
              loss_fct = nn.CrossEntropyLoss(ignore_index=-100, reduction="mean")

              loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1))

          if not return_dict:
              # tuple output: entries that are None are dropped
              return tuple(
                  v
                  for v in [
                      loss,
                      logits,
                      past_key_values,
                      all_hidden_states,
                      all_attentions,
                      all_cross_attentions,
                  ]
                  if v is not None
              )
          return CausalLMOutputWithCrossAttentions(
              loss=loss,
              logits=logits,
              past_key_values=past_key_values,
              hidden_states=all_hidden_states,
              attentions=all_attentions,
              cross_attentions=all_cross_attentions,
          )

      # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
      def _update_causal_mask(
          self,
          attention_mask: Union[torch.Tensor, "BlockMask"],
          input_tensor: torch.Tensor,
          cache_position: torch.Tensor,
          past_key_values: Cache,
          output_attentions: bool = False,
      ):
          if self.config._attn_implementation == "flash_attention_2":
              if attention_mask is not None and (attention_mask == 0.0).any():
                  return attention_mask
              return None
          if self.config._attn_implementation == "flex_attention":
              if isinstance(attention_mask, torch.Tensor):
                  attention_mask = make_flex_block_causal_mask(attention_mask)
              return attention_mask

          # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
          # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
          # to infer the attention mask.
          past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
          using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

          # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
          if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
              if AttentionMaskConverter._ignore_causal_mask_sdpa(
                  attention_mask,
                  inputs_embeds=input_tensor,
                  past_key_values_length=past_seen_tokens,
                  is_training=self.training,
              ):
                  return None

          dtype = input_tensor.dtype
          sequence_length = input_tensor.shape[1]
          if using_compilable_cache:
              target_length = past_key_values.get_max_cache_shape()
          else:
              target_length = (
                  attention_mask.shape[-1]
                  if isinstance(attention_mask, torch.Tensor)
                  else past_seen_tokens + sequence_length + 1
              )

          # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
          causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
              attention_mask,
              sequence_length=sequence_length,
              target_length=target_length,
              dtype=dtype,
              cache_position=cache_position,
              batch_size=input_tensor.shape[0],
          )

          if (
              self.config._attn_implementation == "sdpa"
              and attention_mask is not None
              and attention_mask.device.type in ["cuda", "xpu", "npu"]
              and not output_attentions
          ):
              # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
              # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
              # Details: https://github.com/pytorch/pytorch/issues/110213
              min_dtype = torch.finfo(dtype).min
              causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

          return causal_mask

      @staticmethod
      # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
      def _prepare_4d_causal_attention_mask_with_cache_position(
          attention_mask: torch.Tensor,
          sequence_length: int,
          target_length: int,
          dtype: torch.dtype,
          cache_position: torch.Tensor,
          batch_size: int,
          **kwargs,
      ):
          """
          Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
          `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

          Args:
              attention_mask (`torch.Tensor`):
                  A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                  `(batch_size, 1, query_length, key_value_length)`.
              sequence_length (`int`):
                  The sequence length being processed.
              target_length (`int`):
                  The target length: when generating with static cache, the mask should be as long as the static cache,
                  to account for the 0 padding, the part of the cache that is not filled yet.
              dtype (`torch.dtype`):
                  The dtype to use for the 4D attention mask.
              cache_position (`torch.Tensor`):
                  Indices depicting the position of the input sequence tokens in the sequence.
              batch_size (`torch.Tensor`):
                  Batch size.
          """
          if attention_mask is not None and attention_mask.dim() == 4:
              # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
              causal_mask = attention_mask
          else:
              min_dtype = torch.finfo(dtype).min
              causal_mask = torch.full(
                  (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
              )
              if sequence_length != 1:
                  causal_mask = torch.triu(causal_mask, diagonal=1)
              causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
              causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
              if attention_mask is not None:
                  causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                  mask_length = attention_mask.shape[-1]
                  padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                      causal_mask.device
                  )
                  padding_mask = padding_mask == 0
                  causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                      padding_mask, min_dtype
                  )

          return causal_mask
  1170. @auto_docstring(
  1171. custom_intro="""
  1172. A conditional generation model with a language modeling head. Can be used for sequence generation tasks.
  1173. """
  1174. )
  1175. class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin):
  1176. config: Pix2StructConfig
  1177. main_input_name = "flattened_patches"
  1178. _tied_weights_keys = ["decoder.lm_head.weight"]
  1179. def __init__(self, config: Pix2StructConfig):
  1180. super().__init__(config)
  1181. self.encoder = Pix2StructVisionModel(config.vision_config)
  1182. self.decoder = Pix2StructTextModel(config.text_config)
  1183. self.is_vqa = config.is_vqa
  1184. # Initialize weights and apply final processing
  1185. self.post_init()
  1186. def get_input_embeddings(self):
  1187. return self.decoder.get_input_embeddings()
  1188. def set_input_embeddings(self, new_embeddings):
  1189. self.decoder.set_input_embeddings(new_embeddings)
  1190. def get_output_embeddings(self) -> nn.Module:
  1191. return self.decoder.get_output_embeddings()
  1192. def set_output_embeddings(self, new_embeddings):
  1193. self.decoder.set_output_embeddings(new_embeddings)
  1194. def get_encoder(self):
  1195. return self.encoder
  1196. @auto_docstring
  1197. def forward(
  1198. self,
  1199. flattened_patches: Optional[torch.FloatTensor] = None,
  1200. attention_mask: Optional[torch.FloatTensor] = None,
  1201. decoder_input_ids: Optional[torch.LongTensor] = None,
  1202. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1203. head_mask: Optional[torch.FloatTensor] = None,
  1204. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1205. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1206. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  1207. past_key_values: Optional[Cache] = None,
  1208. labels: Optional[torch.LongTensor] = None,
  1209. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1210. use_cache: Optional[bool] = None,
  1211. output_attentions: Optional[bool] = None,
  1212. output_hidden_states: Optional[bool] = None,
  1213. return_dict: Optional[bool] = None,
  1214. cache_position: Optional[torch.LongTensor] = None,
  1215. ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  1216. r"""
  1217. flattened_patches (`torch.FloatTensor` of shape `(batch_size, seq_length, hidden_size)`):
  1218. Flattened pixel patches. the `hidden_size` is obtained by the following formula: `hidden_size` =
  1219. `num_channels` * `patch_size` * `patch_size`
  1220. The process of flattening the pixel patches is done by `Pix2StructProcessor`.
  1221. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1222. Indices of decoder input sequence tokens in the vocabulary.
  1223. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1224. [`PreTrainedTokenizer.__call__`] for details.
  1225. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1226. Pix2StructText uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1227. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1228. `past_key_values`).
  1229. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [Pix2StructText
  1230. Training](./t5#training).
  1231. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1232. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1233. be used by default.
  1234. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1235. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1236. 1]`:
  1237. - 1 indicates the head is **not masked**,
  1238. - 0 indicates the head is **masked**.
  1239. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1240. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1241. `[0, 1]`:
  1242. - 1 indicates the head is **not masked**,
  1243. - 0 indicates the head is **masked**.
  1244. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1245. Labels for computing the masked language modeling loss for the decoder.
  1246. Example:
  1247. Inference:
  1248. ```python
  1249. >>> from PIL import Image
  1250. >>> import requests
  1251. >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
  1252. >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
  1253. >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
  1254. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  1255. >>> image = Image.open(requests.get(url, stream=True).raw)
  1256. >>> inputs = processor(images=image, return_tensors="pt")
  1257. >>> # autoregressive generation
  1258. >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
  1259. >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  1260. >>> print(generated_text)
  1261. A stop sign is on a street corner.
  1262. >>> # conditional generation
  1263. >>> text = "A picture of"
  1264. >>> inputs = processor(text=text, images=image, return_tensors="pt", add_special_tokens=False)
  1265. >>> generated_ids = model.generate(**inputs, max_new_tokens=50)
  1266. >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  1267. >>> print(generated_text)
  1268. A picture of a stop sign with a red stop sign
  1269. ```
  1270. Training:
  1271. ```python
  1272. >>> from PIL import Image
  1273. >>> import requests
  1274. >>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
  1275. >>> processor = AutoProcessor.from_pretrained("google/pix2struct-base")
  1276. >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")
  1277. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  1278. >>> image = Image.open(requests.get(url, stream=True).raw)
  1279. >>> text = "A stop sign is on the street corner."
  1280. >>> inputs = processor(images=image, return_tensors="pt")
  1281. >>> labels = processor(text=text, return_tensors="pt").input_ids
  1282. >>> # forward pass
  1283. >>> outputs = model(**inputs, labels=labels)
  1284. >>> loss = outputs.loss
  1285. >>> print(f"{loss.item():.5f}")
  1286. 5.94282
  1287. ```"""
  1288. use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
  1289. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1290. # Encode if needed (training, first prediction pass)
  1291. if encoder_outputs is None:
  1292. encoder_outputs = self.encoder(
  1293. flattened_patches=flattened_patches,
  1294. attention_mask=attention_mask,
  1295. head_mask=head_mask,
  1296. output_attentions=output_attentions,
  1297. output_hidden_states=output_hidden_states,
  1298. return_dict=return_dict,
  1299. )
  1300. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1301. encoder_outputs = BaseModelOutput(
  1302. last_hidden_state=encoder_outputs[0],
  1303. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1304. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1305. )
  1306. hidden_states = encoder_outputs[0]
  1307. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1308. # get decoder inputs from shifting lm labels to the right
  1309. decoder_input_ids = self._shift_right(labels)
  1310. decoder_attention_mask = (
  1311. decoder_attention_mask
  1312. if decoder_attention_mask is not None
  1313. else decoder_input_ids.ne(self.config.pad_token_id).float()
  1314. )
  1315. # Always attend to the first token
  1316. decoder_attention_mask[:, 0] = 1
  1317. # Decode
  1318. decoder_outputs = self.decoder(
  1319. input_ids=decoder_input_ids,
  1320. attention_mask=decoder_attention_mask,
  1321. inputs_embeds=decoder_inputs_embeds,
  1322. past_key_values=past_key_values,
  1323. encoder_hidden_states=hidden_states,
  1324. encoder_attention_mask=attention_mask,
  1325. head_mask=decoder_head_mask,
  1326. cross_attn_head_mask=cross_attn_head_mask,
  1327. use_cache=use_cache,
  1328. output_attentions=output_attentions,
  1329. output_hidden_states=output_hidden_states,
  1330. labels=labels,
  1331. return_dict=return_dict,
  1332. cache_position=cache_position,
  1333. )
  1334. if not return_dict:
  1335. return decoder_outputs + encoder_outputs
  1336. return Seq2SeqLMOutput(
  1337. loss=decoder_outputs.loss,
  1338. logits=decoder_outputs.logits,
  1339. past_key_values=decoder_outputs.past_key_values,
  1340. decoder_hidden_states=decoder_outputs.hidden_states,
  1341. decoder_attentions=decoder_outputs.attentions,
  1342. cross_attentions=decoder_outputs.cross_attentions,
  1343. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1344. encoder_hidden_states=encoder_outputs.hidden_states,
  1345. encoder_attentions=encoder_outputs.attentions,
  1346. )
# Public API of this module: the names exported by a star-import of this file.
__all__ = [
    "Pix2StructPreTrainedModel",
    "Pix2StructForConditionalGeneration",
    "Pix2StructVisionModel",
    "Pix2StructTextModel",
]