modeling_udop.py 90 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007
  1. # coding=utf-8
  2. # Copyright 2024 Microsoft Research and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch UDOP model."""
  16. import collections
  17. import logging
  18. import math
  19. import random
  20. from abc import ABC, abstractmethod
  21. from collections.abc import Sequence
  22. from copy import deepcopy
  23. from dataclasses import dataclass
  24. from typing import Any, Optional, Union
  25. import torch
  26. from torch import Tensor, nn
  27. from torch.nn import CrossEntropyLoss
  28. from transformers import UdopConfig
  29. from transformers.modeling_outputs import (
  30. Seq2SeqLMOutput,
  31. Seq2SeqModelOutput,
  32. )
  33. from ...activations import ACT2FN
  34. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  35. from ...generation import GenerationMixin
  36. from ...modeling_attn_mask_utils import AttentionMaskConverter
  37. from ...modeling_layers import GradientCheckpointingLayer
  38. from ...modeling_utils import PreTrainedModel
  39. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  40. from ...utils import (
  41. ModelOutput,
  42. auto_docstring,
  43. is_torch_flex_attn_available,
  44. is_torchdynamo_compiling,
  45. )
  46. from ...utils.deprecation import deprecate_kwarg
  47. if is_torch_flex_attn_available():
  48. from torch.nn.attention.flex_attention import BlockMask
  49. from ...integrations.flex_attention import make_flex_block_causal_mask
# NOTE(review): this is the standard-library logger (`import logging` at the top of the file). A stdlib
# `logging.Logger` has no `warning_once` method, yet UdopAttention.__init__ calls `logger.warning_once(...)`;
# transformers' own `utils.logging.get_logger` provides it — confirm which logger is intended here.
logger = logging.getLogger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Class for the model's outputs that may also contain a past key/values (to speed up sequential decoding). Includes
    an additional attention mask.
    """
)
class BaseModelOutputWithAttentionMask(ModelOutput):
    r"""
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model. If `past_key_values` is used only
        the last hidden-state of the sequences of shape `(batch_size, 1, hidden_size)` is output.
    attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Attention mask used in the model's forward pass to avoid performing attention on padding token indices.
        Mask values selected in `[0, 1]`:
        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
        Contains pre-computed hidden-states (key and values in the
        self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks)
        that can be used (see `past_key_values` input) to speed up sequential decoding.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of
        the model at the output of each layer plus the optional initial embedding outputs.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
        the self-attention heads.
    cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
        used to compute the weighted average in the cross-attention heads.
    """

    # NOTE: field order matters — ModelOutput exposes these fields positionally as well as by name.
    # Output of the final layer.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Mask returned alongside the hidden states (1 = attend, 0 = masked).
    attention_mask: Optional[torch.FloatTensor] = None
    # Key/value cache for fast sequential decoding.
    past_key_values: Optional[Cache] = None
    # Per-layer hidden states when `output_hidden_states=True`.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Per-layer self-attention weights when `output_attentions=True`.
    attentions: Optional[tuple[torch.FloatTensor]] = None
    # Per-layer cross-attention weights when attentions are requested and cross-attention is enabled.
    cross_attentions: Optional[tuple[torch.FloatTensor]] = None
  92. def get_visual_bbox(image_size=224, patch_size=16):
  93. image_feature_pool_shape = [image_size // patch_size, image_size // patch_size]
  94. visual_bbox_x = torch.arange(0, 1.0 * (image_feature_pool_shape[1] + 1), 1.0)
  95. visual_bbox_x /= image_feature_pool_shape[1]
  96. visual_bbox_y = torch.arange(0, 1.0 * (image_feature_pool_shape[0] + 1), 1.0)
  97. visual_bbox_y /= image_feature_pool_shape[0]
  98. visual_bbox_input = torch.stack(
  99. [
  100. visual_bbox_x[:-1].repeat(image_feature_pool_shape[0], 1),
  101. visual_bbox_y[:-1].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
  102. visual_bbox_x[1:].repeat(image_feature_pool_shape[0], 1),
  103. visual_bbox_y[1:].repeat(image_feature_pool_shape[1], 1).transpose(0, 1),
  104. ],
  105. dim=-1,
  106. )
  107. visual_bbox_input = visual_bbox_input.view(-1, 4)
  108. return visual_bbox_input
  109. def pad_sequence(seq, target_len, pad_value=0):
  110. if isinstance(seq, torch.Tensor):
  111. n = seq.shape[0]
  112. else:
  113. n = len(seq)
  114. seq = torch.tensor(seq)
  115. m = target_len - n
  116. if m > 0:
  117. ret = torch.stack([pad_value] * m).to(seq)
  118. seq = torch.cat([seq, ret], dim=0)
  119. return seq[:target_len]
  120. def combine_image_text_embeddings(
  121. image_embeddings,
  122. inputs_embeds,
  123. bbox,
  124. visual_bbox,
  125. attention_mask=None,
  126. num_patches=14,
  127. max_len=0,
  128. image_size=224,
  129. patch_size=16,
  130. ):
  131. """
  132. Combine the image and text embeddings for the input to the encoder/decoder of UDOP.
  133. First, the image embeddings are created by checking for each visual patch if it is inside the bounding box of a
  134. token. If it is, the visual patch is combined with the token embedding. Then, the visual bounding boxes are combined
  135. with the text bounding boxes. Finally, the visual bounding boxes are combined with the text attention mask.
  136. """
  137. sequence_length = num_patches
  138. ocr_points_x = torch.clip(
  139. torch.floor((bbox[:, :, 0] + bbox[:, :, 2]) / 2.0 * sequence_length).long(), 0, sequence_length - 1
  140. )
  141. ocr_points_y = (
  142. torch.clip(torch.floor((bbox[:, :, 1] + bbox[:, :, 3]) / 2.0 * sequence_length).long(), 0, sequence_length - 1)
  143. * sequence_length
  144. )
  145. ocr_points = ocr_points_x + ocr_points_y
  146. # make sure bounding boxes are of type float to calculate means
  147. bbox = bbox.to(torch.float64)
  148. target_seg = (bbox.mean(-1) == 0.0) | (bbox.mean(-1) == 1.0)
  149. repeated_vision_embeds = torch.gather(
  150. image_embeddings, 1, ocr_points.unsqueeze(-1).repeat(1, 1, image_embeddings.size(-1))
  151. )
  152. repeated_vision_embeds[target_seg] = 0.0
  153. inputs_embeds += repeated_vision_embeds
  154. patch_inds = torch.full_like(image_embeddings[:, :, 0], True).bool()
  155. ind = torch.cat(
  156. [
  157. torch.arange(len(ocr_points))[:, None].repeat(1, ocr_points.size(-1))[:, :, None].to(ocr_points),
  158. ocr_points[:, :, None],
  159. ],
  160. dim=-1,
  161. )
  162. ind = ind.flatten(0, 1)
  163. rows, cols = zip(*ind)
  164. patch_inds[rows, cols] = False
  165. input_vision_patches = [image_embeddings[i][patch_inds[i]] for i in range(len(patch_inds))]
  166. if visual_bbox is None:
  167. visual_bbox = get_visual_bbox(image_size=image_size, patch_size=patch_size)
  168. visual_bbox = visual_bbox.unsqueeze(0).repeat(image_embeddings.size(0), 1, 1)
  169. visual_bbox = visual_bbox.to(image_embeddings.device)
  170. visual_bbox = [visual_bbox[i][patch_inds[i]] for i in range(len(patch_inds))]
  171. if attention_mask is not None:
  172. visual_attention_mask = [torch.tensor([1] * len(item)).to(attention_mask) for item in visual_bbox]
  173. if max_len == 0:
  174. max_len = image_embeddings.size(1)
  175. else:
  176. max_len = max_len - inputs_embeds.size(1)
  177. inputs_vision_patches = torch.stack(
  178. [pad_sequence(item, max_len, torch.zeros_like(image_embeddings[0, 0])) for item in input_vision_patches]
  179. )
  180. visual_bbox = torch.stack([pad_sequence(item, max_len, torch.zeros_like(bbox[0, 0])) for item in visual_bbox])
  181. if attention_mask is not None:
  182. visual_attention_mask = torch.stack(
  183. [pad_sequence(item, max_len, torch.zeros_like(attention_mask[0, 0])) for item in visual_attention_mask]
  184. )
  185. inputs_embeds = torch.cat([inputs_embeds, inputs_vision_patches], 1)
  186. bbox = torch.cat([bbox, visual_bbox], 1)
  187. if attention_mask is not None:
  188. attention_mask = torch.cat([attention_mask, visual_attention_mask], 1)
  189. return inputs_embeds, bbox, attention_mask
  190. class UdopPatchEmbeddings(nn.Module):
  191. """2D Image to Patch Embeddings"""
  192. def __init__(self, config):
  193. super().__init__()
  194. image_size, patch_size = config.image_size, config.patch_size
  195. num_channels, hidden_size = config.num_channels, config.hidden_size
  196. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  197. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  198. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  199. self.image_size = image_size
  200. self.patch_size = patch_size
  201. self.num_channels = num_channels
  202. self.num_patches = num_patches
  203. self.proj = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  204. def forward(self, pixel_values):
  205. batch_size, num_channels, height, width = pixel_values.shape
  206. if height != self.image_size[0] or width != self.image_size[1]:
  207. raise ValueError(
  208. f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
  209. )
  210. embeddings = self.proj(pixel_values)
  211. embeddings = embeddings.flatten(2).transpose(1, 2)
  212. return embeddings
@auto_docstring
class UdopPreTrainedModel(PreTrainedModel):
    # Shared base for the UDOP models: declares the config type and checkpoint prefix, and implements the
    # per-module weight initialization and the decoder-input right shift.
    config: UdopConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _can_compile_fullgraph = False
    # NOTE(review): presumably consumed by PreTrainedModel to keep modules named `wo` in fp32; UdopDenseActDense /
    # UdopDenseGatedActDense cast their input to `wo.weight.dtype` before `wo` for this reason — confirm.
    _keep_in_fp32_modules = ["wo"]

    def _init_weights(self, module: nn.Module):
        """Initialize the weights of `module` in place.

        Dispatches on the module type; the first matching `isinstance` branch wins, so branch order matters.
        Standard deviations are scaled by `config.initializer_factor`. Biases, and the embedding padding row,
        are zeroed where the branch handles them.
        """
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, UdopLayerNorm):
            # RMS-norm scale is set to the constant `factor`.
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=factor)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Conv2d):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(module.weight.data.to(torch.float32), mean=0.0, std=factor).to(
                module.weight.dtype
            )
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, RelativePositionBiasBase):
            # Only the bias embedding is initialized: N(0, factor / sqrt(d_model)).
            factor = self.config.initializer_factor
            d_model = self.config.d_model
            module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
        elif isinstance(module, UdopModel):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, UdopForConditionalGeneration):
            # The LM head gets its own init only when it is not tied to the input embeddings.
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
        elif isinstance(module, UdopDenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            # Std is scaled by the fan-in of each projection: d_model for `wi`, d_ff for `wo`.
            module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, UdopDenseGatedActDense):
            # Same fan-in scaling as above, applied to both gate projections `wi_0` and `wi_1`.
            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
                module.wi_0.bias.data.zero_()
            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
                module.wi_1.bias.data.zero_()
            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.zero_()
        elif isinstance(module, UdopAttention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            # The extra 1/sqrt(d_kv) in the `q` std stands in for the usual score scaling.
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

    # Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetPreTrainedModel._shift_right with ProphetNet->Udop
    def _shift_right(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Build decoder inputs from `input_ids` by shifting them one position to the right.

        Position 0 is filled with `config.decoder_start_token_id`, and any -100 (ignored-label) values are replaced
        by `config.pad_token_id`.

        NOTE(review): the checks below use `assert`, which is stripped under `python -O`.
        """
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        assert decoder_start_token_id is not None, (
            "self.model.config.decoder_start_token_id has to be defined. In Udop it is usually set to the"
            " pad_token_id. See Udop docs for more information"
        )

        # shift inputs to the right
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = decoder_start_token_id

        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        assert torch.all(shifted_input_ids >= 0).item(), "Verify that `shifted_input_ids` has only positive values"

        return shifted_input_ids
  297. # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->Udop
  298. class UdopLayerNorm(nn.Module):
  299. def __init__(self, hidden_size, eps=1e-6):
  300. """
  301. Construct a layernorm module in the Udop style. No bias and no subtraction of mean.
  302. """
  303. super().__init__()
  304. self.weight = nn.Parameter(torch.ones(hidden_size))
  305. self.variance_epsilon = eps
  306. def forward(self, hidden_states):
  307. # Udop uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  308. # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  309. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  310. # half-precision inputs is done in fp32
  311. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  312. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  313. # convert into half-precision if necessary
  314. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  315. hidden_states = hidden_states.to(self.weight.dtype)
  316. return self.weight * hidden_states
  317. # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->Udop
  318. class UdopDenseActDense(nn.Module):
  319. def __init__(self, config: UdopConfig):
  320. super().__init__()
  321. self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
  322. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  323. self.dropout = nn.Dropout(config.dropout_rate)
  324. self.act = ACT2FN[config.dense_act_fn]
  325. def forward(self, hidden_states):
  326. hidden_states = self.wi(hidden_states)
  327. hidden_states = self.act(hidden_states)
  328. hidden_states = self.dropout(hidden_states)
  329. if (
  330. isinstance(self.wo.weight, torch.Tensor)
  331. and hidden_states.dtype != self.wo.weight.dtype
  332. and self.wo.weight.dtype != torch.int8
  333. ):
  334. hidden_states = hidden_states.to(self.wo.weight.dtype)
  335. hidden_states = self.wo(hidden_states)
  336. return hidden_states
  337. # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->Udop
  338. class UdopDenseGatedActDense(nn.Module):
  339. def __init__(self, config: UdopConfig):
  340. super().__init__()
  341. self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
  342. self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
  343. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  344. self.dropout = nn.Dropout(config.dropout_rate)
  345. self.act = ACT2FN[config.dense_act_fn]
  346. def forward(self, hidden_states):
  347. hidden_gelu = self.act(self.wi_0(hidden_states))
  348. hidden_linear = self.wi_1(hidden_states)
  349. hidden_states = hidden_gelu * hidden_linear
  350. hidden_states = self.dropout(hidden_states)
  351. # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
  352. # See https://github.com/huggingface/transformers/issues/20287
  353. # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
  354. if (
  355. isinstance(self.wo.weight, torch.Tensor)
  356. and hidden_states.dtype != self.wo.weight.dtype
  357. and self.wo.weight.dtype != torch.int8
  358. ):
  359. hidden_states = hidden_states.to(self.wo.weight.dtype)
  360. hidden_states = self.wo(hidden_states)
  361. return hidden_states
  362. # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->Udop
  363. class UdopLayerFF(nn.Module):
  364. def __init__(self, config: UdopConfig):
  365. super().__init__()
  366. if config.is_gated_act:
  367. self.DenseReluDense = UdopDenseGatedActDense(config)
  368. else:
  369. self.DenseReluDense = UdopDenseActDense(config)
  370. self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  371. self.dropout = nn.Dropout(config.dropout_rate)
  372. def forward(self, hidden_states):
  373. forwarded_states = self.layer_norm(hidden_states)
  374. forwarded_states = self.DenseReluDense(forwarded_states)
  375. hidden_states = hidden_states + self.dropout(forwarded_states)
  376. return hidden_states
  377. # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop
  378. class UdopAttention(nn.Module):
  379. def __init__(
  380. self,
  381. config: UdopConfig,
  382. has_relative_attention_bias=False,
  383. layer_idx: Optional[int] = None,
  384. ):
  385. super().__init__()
  386. self.is_decoder = config.is_decoder
  387. self.has_relative_attention_bias = has_relative_attention_bias
  388. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  389. self.relative_attention_max_distance = config.relative_attention_max_distance
  390. self.d_model = config.d_model
  391. self.key_value_proj_dim = config.d_kv
  392. self.n_heads = config.num_heads
  393. self.dropout = config.dropout_rate
  394. self.inner_dim = self.n_heads * self.key_value_proj_dim
  395. self.layer_idx = layer_idx
  396. if layer_idx is None and self.is_decoder:
  397. logger.warning_once(
  398. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  399. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  400. "when creating this class."
  401. )
  402. # Mesh TensorFlow initialization to avoid scaling before softmax
  403. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  404. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  405. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  406. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  407. if self.has_relative_attention_bias:
  408. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  409. self.pruned_heads = set()
  410. self.gradient_checkpointing = False
  411. def prune_heads(self, heads):
  412. if len(heads) == 0:
  413. return
  414. heads, index = find_pruneable_heads_and_indices(
  415. heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
  416. )
  417. # Prune linear layers
  418. self.q = prune_linear_layer(self.q, index)
  419. self.k = prune_linear_layer(self.k, index)
  420. self.v = prune_linear_layer(self.v, index)
  421. self.o = prune_linear_layer(self.o, index, dim=1)
  422. # Update hyper params
  423. self.n_heads = self.n_heads - len(heads)
  424. self.inner_dim = self.key_value_proj_dim * self.n_heads
  425. self.pruned_heads = self.pruned_heads.union(heads)
    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid (they are clamped to distance 0).
        We use smaller buckets for small absolute relative_position and larger buckets for larger absolute
        relative_positions. All relative positions >=max_distance map to the same bucket. All relative positions
        <=-max_distance map to the same bucket. This should allow for more graceful generalization to longer sequences
        than the model has been trained on.

        Args:
            relative_position: an integer Tensor of relative positions (callers in this file pass `torch.long`)
            bidirectional: a boolean - whether the attention is bidirectional; if True, half of the buckets are used
                for positive offsets and half for non-positive ones
            num_buckets: an integer - total number of buckets
            max_distance: an integer - distances at or beyond this value share the last bucket

        Returns:
            a Tensor with the same shape as relative_position, containing integer values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            # Split the buckets between directions: positive offsets are shifted into the upper half.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # Unidirectional: positive offsets collapse to 0, negative offsets become positive distances.
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance.
        # A distance of 0 produces log(0) = -inf here; such entries are always `is_small` when max_exact > 0,
        # so the torch.where below never selects them.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        # Clamp so every distance past max_distance lands in the last bucket.
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets
  467. def compute_bias(self, query_length, key_length, device=None, cache_position=None):
  468. """Compute binned relative position bias"""
  469. if device is None:
  470. device = self.relative_attention_bias.weight.device
  471. if cache_position is None:
  472. context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
  473. else:
  474. context_position = cache_position[:, None].to(device)
  475. memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
  476. relative_position = memory_position - context_position # shape (query_length, key_length)
  477. relative_position_bucket = self._relative_position_bucket(
  478. relative_position, # shape (query_length, key_length)
  479. bidirectional=(not self.is_decoder),
  480. num_buckets=self.relative_attention_num_buckets,
  481. max_distance=self.relative_attention_max_distance,
  482. )
  483. values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
  484. values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
  485. return values
    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_values=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Args:
            hidden_states: query-side input, (batch_size, seq_length, dim).
            mask: additive attention mask. It is sliced to the key length and folded into the position bias, but
                only when `position_bias` is None (i.e. when the bias is computed here).
            key_value_states: source states for cross-attention; `None` selects self-attention.
            position_bias: precomputed additive bias; computed in this method when `None`.
            past_key_values: optional cache object; `EncoderDecoderCache` selects its self- or cross-attention part.
                The cache is updated in place.
            layer_head_mask: multiplied into the attention probabilities after dropout.
            query_length: total query length for the relative bias; defaults to `cache_position[-1] + 1`.
            use_cache: not read in this method.
            output_attentions: if True, also return the attention probabilities.
            cache_position: positions of the current queries; not forwarded to the cache update for cross-attention.

        Returns:
            `(attn_output, position_bias)`, plus `attn_weights` when `output_attentions` is True. The returned
            `position_bias` includes the mask when it was added here, and covers all heads (before pruned-head
            selection).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
        is_updated = False
        if isinstance(past_key_values, EncoderDecoderCache):
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_values.cross_attention_cache
            else:
                curr_past_key_value = past_key_values.self_attention_cache
        else:
            curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                # `update` returns the full (cached + new) key/value tensors used below
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        # NOTE(review): no 1/sqrt(d) scaling is applied here — T5-style attention; confirm this is intended.
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                # Layers without their own bias table start from an all-zero bias.
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length, key_length, device=scores.device, cache_position=cache_position
                )
                # keep only the rows for the queries processed in this call
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            # Boolean selector over the original heads; `mask` is reused here as a local name for it.
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked

        # (batch_size, n_heads, seq_length, key_length)
        # softmax in float32 for stability, then cast back to the scores' dtype
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        # merge heads back: (batch, heads, seq, d) -> (batch, seq, heads * d)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop
class UdopLayerSelfAttention(nn.Module):
    """Self-attention sub-layer: layer norm, `UdopAttention`, dropout, then a residual add onto the input."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.SelfAttention = UdopAttention(
            config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
        )
        self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Returns:
            A tuple whose first element is the updated hidden states, followed by every output of `UdopAttention`
            after its first (the position bias, then the attention weights when `output_attentions` is True).
        """
        # Attention runs on the normalized states; the residual below adds to the un-normalized input.
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
            normed_hidden_states,
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop
class UdopLayerCrossAttention(nn.Module):
    """Cross-attention sub-layer: layer norm, `UdopAttention` over `key_value_states`, dropout, residual add."""

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        # Cross-attention never owns a relative bias table; its bias is passed in or zero-initialized.
        self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
        self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        key_value_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_values=None,
        use_cache=False,
        query_length=None,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Returns:
            A tuple whose first element is the updated hidden states, followed by every output of `UdopAttention`
            after its first (the position bias, then the attention weights when `output_attentions` is True).
        """
        # Only the query side is normalized; `key_value_states` are used as given.
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = self.EncDecAttention(
            normed_hidden_states,
            mask=attention_mask,
            key_value_states=key_value_states,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            query_length=query_length,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs
# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop
class UdopBlock(GradientCheckpointingLayer):
    """
    One transformer layer. `self.layer` holds the self-attention sub-layer, then (decoder only) the cross-attention
    sub-layer, and finally the feed-forward sub-layer. With float16 hidden states, the output of each sub-layer is
    clamped to the float16 range.
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(
            UdopLayerSelfAttention(
                config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
            )
        )
        if self.is_decoder:
            self.layer.append(UdopLayerCrossAttention(config, layer_idx=layer_idx))

        self.layer.append(UdopLayerFF(config))

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
    ):
        """
        Returns:
            `(hidden_states, self_attn_position_bias[, self_attn_weights][, cross_attn_position_bias
            [, cross_attn_weights]])` — weights only when `output_attentions` is True, cross-attention entries only
            when cross-attention ran. `return_dict` is not read here.
        """
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        hidden_states = self_attention_outputs[0]
        attention_outputs = self_attention_outputs[1:]  # Keep self-attention outputs and relative position weights

        # clamp inf values to enable fp16 training
        # (clamp slightly below the max when an inf is present, otherwise at the max itself)
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            # cache_position is 0-indexed, so the last entry + 1 is the total query length so far
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = cross_attention_outputs[0]

            # clamp inf values to enable fp16 training
            if hidden_states.dtype == torch.float16:
                clamp_value = torch.where(
                    torch.isinf(hidden_states).any(),
                    torch.finfo(hidden_states.dtype).max - 1000,
                    torch.finfo(hidden_states.dtype).max,
                )
                hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

            # Keep cross-attention outputs and relative position weights
            attention_outputs = attention_outputs + cross_attention_outputs[1:]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        # clamp inf values to enable fp16 training
        if hidden_states.dtype == torch.float16:
            clamp_value = torch.where(
                torch.isinf(hidden_states).any(),
                torch.finfo(hidden_states.dtype).max - 1000,
                torch.finfo(hidden_states.dtype).max,
            )
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        return (
            outputs + attention_outputs
        )  # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  742. class UdopCellEmbeddings(nn.Module):
  743. def __init__(self, max_2d_position_embeddings=501, hidden_size=1024):
  744. super().__init__()
  745. self.max_2d_position_embeddings = max_2d_position_embeddings
  746. self.x_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size)
  747. self.y_position_embeddings = nn.Embedding(max_2d_position_embeddings, hidden_size)
  748. def forward(self, bbox):
  749. bbox = torch.clip(bbox, 0.0, 1.0)
  750. bbox = (bbox * (self.max_2d_position_embeddings - 1)).long()
  751. left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
  752. upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
  753. right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
  754. lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
  755. embeddings = (
  756. left_position_embeddings
  757. + upper_position_embeddings
  758. + right_position_embeddings
  759. + lower_position_embeddings
  760. )
  761. return embeddings
  762. # get function for bucket computation
  763. # protected member access seems to be lesser evil than copy paste whole function
  764. get_relative_position_bucket = UdopAttention._relative_position_bucket
  765. AUGMENTATION_RANGE = (0.80, 1.25)
  766. class RelativePositionBiasBase(nn.Module, ABC):
  767. """
  768. Base class of relative biases.
  769. Args:
  770. num_heads (`int`):
  771. Number of attention heads in the model, it will create embeddings of size `num_heads`, which will be added to the scores of each token pair.
  772. relative_attention_num_buckets (`int`, *optional*, defaults to 32):
  773. Pair token metric (distance in the sequence, distance in pixels etc.) will be bucketed, parameter is defining number of such
  774. buckets.
  775. bidirectional (`bool`, *optional*, defaults to `True`):
  776. Whether the distance should be bidirectional for a pair of tokens. If `False`, then distance(tok1, tok2) == distance(tok2, tok1).
  777. scaling_factor (`int`, *optional*, defaults to 1):
  778. Defining factor which will be used to scale relative distance.
  779. max_distance (`int`, *optional*, defaults to 128):
  780. All distances above this value will end up in the one/same bucket.
  781. augmentation (`bool`, *optional*, defaults to `False`):
  782. Whether to multiply relative distances by a random scalar.
  783. expand (`bool`, *optional*, defaults to `False`):
  784. Whether to expand an existing pretrained model with subsequent additions of prefix_bucket.
  785. """
  786. def __init__(
  787. self,
  788. num_heads=None,
  789. relative_attention_num_buckets=32,
  790. bidirectional=True,
  791. scaling_factor=1,
  792. max_distance=128,
  793. level="tokens",
  794. augmentation=False,
  795. prefix_bucket=False,
  796. expand=False,
  797. ):
  798. super().__init__()
  799. self.prefix_bucket = prefix_bucket
  800. self.augmentation = augmentation
  801. self.level = level
  802. self.max_distance = max_distance
  803. self.scaling_factor = scaling_factor
  804. self.bidirectional = bidirectional
  805. self.num_heads = num_heads
  806. self.expand = expand
  807. self.relative_attention_num_buckets = relative_attention_num_buckets
  808. extra_head = 2 if prefix_bucket and not self.expand else 0
  809. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets + extra_head, self.num_heads)
  810. @abstractmethod
  811. def prepare_input(
  812. self,
  813. attention_mask: Optional[Tensor] = None,
  814. bbox: Optional[dict[str, Any]] = None,
  815. ) -> Tensor:
  816. pass
  817. def get_bucket(self, attention_mask: Optional[Tensor] = None, bbox: Optional[dict[str, Any]] = None) -> Tensor:
  818. relative_position = self.prepare_input(attention_mask, bbox)
  819. rp_bucket: Tensor = get_relative_position_bucket(
  820. relative_position,
  821. bidirectional=self.bidirectional,
  822. num_buckets=self.relative_attention_num_buckets,
  823. max_distance=self.max_distance,
  824. )
  825. return rp_bucket
  826. def get_relative_position(self, positions):
  827. context_position = positions[:, :, None]
  828. memory_position = positions[:, None, :]
  829. relative_position = memory_position - context_position
  830. if self.augmentation and self.training:
  831. relative_position *= random.uniform(*AUGMENTATION_RANGE)
  832. relative_position *= self.scaling_factor
  833. return relative_position.to(torch.long)
  834. def forward(self, attention_mask: Optional[Tensor] = None, bbox: Optional[dict[str, Any]] = None) -> Tensor:
  835. # re-using pretrained model with subsequent addition of prefix_bucket
  836. if self.expand and self.prefix_bucket:
  837. new_bias = nn.Embedding(self.relative_attention_num_buckets + 2, self.num_heads)
  838. new_bias.weight.data[: self.relative_attention_num_buckets] = self.relative_attention_bias.weight.data
  839. new_bias.weight.data[self.relative_attention_num_buckets :] = 0.1
  840. self.relative_attention_bias = new_bias
  841. self.expand = False
  842. rp_bucket = self.get_bucket(attention_mask, bbox)
  843. if self.prefix_bucket:
  844. if rp_bucket.size(0) == 1 and attention_mask.size(0) > 1:
  845. rp_bucket = rp_bucket.repeat(attention_mask.size(0), 1, 1)
  846. # based on assumption that prefix bboxes are negative
  847. is_prefix = bbox[:, :, 1] < 0
  848. num_prefix = is_prefix.sum(-1)
  849. for idx, num_prefix_row in enumerate(num_prefix.cpu().numpy()):
  850. rp_bucket[idx, :num_prefix_row, num_prefix_row:] = self.relative_attention_num_buckets
  851. rp_bucket[idx, num_prefix_row:, :num_prefix_row] = self.relative_attention_num_buckets + 1
  852. values: Tensor = self.relative_attention_bias(rp_bucket)
  853. if values.dim() != 4:
  854. raise ValueError("Wrong dimension of values tensor")
  855. values = values.permute([0, 3, 1, 2])
  856. return values
  857. class RelativePositionBias1D(RelativePositionBiasBase):
  858. def __init__(self, scaling_factor=1, max_distance=128, **kwargs):
  859. """
  860. Reimplementation of T5 relative position bias. Distance between given tokens is their distance in the sequence.
  861. Parameters are the same as in base class
  862. """
  863. super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)
  864. def prepare_input(self, attention_mask: Optional[Tensor] = None, bbox: Optional[dict[str, Any]] = None) -> Tensor:
  865. if self.scaling_factor != 1:
  866. raise ValueError("No need to scale 1d features")
  867. relative_position = self.get_relative_position(
  868. torch.arange(attention_mask.size(1), dtype=torch.long, device=attention_mask.device)[None, :]
  869. )
  870. return relative_position
  871. class RelativePositionBiasHorizontal(RelativePositionBiasBase):
  872. def __init__(self, scaling_factor=100, max_distance=100, **kwargs):
  873. """
  874. Represents in the bucket embeddings horizontal distance between two tokens. Parameters are the same as in base
  875. class
  876. """
  877. super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)
  878. def prepare_input(self, attention_mask: Optional[Tensor] = None, bbox: Optional[dict[str, Any]] = None) -> Tensor:
  879. if not self.scaling_factor > 1.0:
  880. raise ValueError("Need to scale the values of bboxes, as there are in small (0,1) range")
  881. if bbox is None:
  882. raise ValueError("Bbox is required for horizontal relative position bias")
  883. # get x positions of left point of bbox
  884. horizontal_position: Tensor = bbox[:, :, [0, 2]].mean(dim=-1)
  885. return self.get_relative_position(horizontal_position)
  886. class RelativePositionBiasVertical(RelativePositionBiasBase):
  887. def __init__(self, scaling_factor=100, max_distance=100, **kwargs):
  888. """
  889. Represents in the bucket embeddings vertical distance between two tokens. Parameters are the same as in base
  890. class
  891. """
  892. super().__init__(scaling_factor=scaling_factor, max_distance=max_distance, **kwargs)
  893. def prepare_input(self, attention_mask: Optional[Tensor] = None, bbox: Optional[dict[str, Any]] = None) -> Tensor:
  894. if not self.scaling_factor > 1.0:
  895. raise ValueError("Need to scale the values of bboxes, as there are in small (0,1) range")
  896. if bbox is None:
  897. raise ValueError("Bbox is required for vertical relative position bias")
  898. # get y positions of middle of bbox
  899. vertical_position: Tensor = bbox[:, :, [1, 3]].mean(dim=-1)
  900. return self.get_relative_position(vertical_position)
  901. class RelativePositionBiasAggregated(nn.Module):
  902. def __init__(self, modules: Sequence[RelativePositionBiasBase]):
  903. """
  904. Class which sums up various computed biases.
  905. Args:
  906. modules (Sequence[RelativePositionBiasBase]):
  907. List of relative bias modules.
  908. """
  909. super().__init__()
  910. self.biases = nn.ModuleList(modules)
  911. def forward(
  912. self, attention_mask: Optional[Tensor] = None, bbox: Optional[dict[str, Any]] = None
  913. ) -> Union[float, Tensor]:
  914. output = 0.0
  915. for bias in self.biases: # type: ignore
  916. output = bias(attention_mask, bbox) + output
  917. return output
  918. BIAS_CLASSES = {
  919. "1d": RelativePositionBias1D,
  920. "horizontal": RelativePositionBiasHorizontal,
  921. "vertical": RelativePositionBiasVertical,
  922. }
  923. def create_relative_bias(config: UdopConfig) -> Sequence[RelativePositionBiasBase]:
  924. """
  925. Creates empty list or one/multiple relative biases.
  926. :param config: Model's configuration :return: Sequence with created bias modules.
  927. """
  928. bias_list = []
  929. if hasattr(config, "relative_bias_args"):
  930. for bias_kwargs_org in config.relative_bias_args:
  931. bias_kwargs = deepcopy(bias_kwargs_org)
  932. bias_type = bias_kwargs.pop("type")
  933. model_num_heads = config.num_heads if hasattr(config, "num_heads") else config.num_attention_heads
  934. if "num_heads" in bias_kwargs:
  935. if bias_kwargs["num_heads"] != model_num_heads:
  936. raise ValueError("Number of heads must match num of heads in the model")
  937. else:
  938. bias_kwargs["num_heads"] = model_num_heads
  939. bias_list.append(BIAS_CLASSES[bias_type](**bias_kwargs)) # type: ignore
  940. return bias_list
  941. class UdopStack(UdopPreTrainedModel):
  942. """
  943. This class is based on `T5Stack`, but modified to take into account the image modality as well as 2D position
  944. embeddings.
  945. """
    def __init__(self, config, embed_tokens=None, embed_patches=None):
        """
        Args:
            config: model configuration (reads `is_decoder`, `max_length`, `num_layers`, `d_model`,
                `layer_norm_epsilon`, `dropout_rate`, and for the encoder `max_2d_position_embeddings`/`hidden_size`).
            embed_tokens: token embedding module used by `forward` when `inputs_embeds` is not given.
            embed_patches: module applied to `pixel_values` in `forward` to produce image embeddings.
        """
        # Submodules are registered in the order assigned below; keep this order unchanged.
        super().__init__(config)
        self.embed_tokens = embed_tokens
        self.embed_patches = embed_patches
        self.is_decoder = config.is_decoder
        self._max_length = config.max_length
        self.num_layers = config.num_layers

        # Only block 0 gets its own relative attention bias table; `forward` passes the position bias returned by
        # each block on to the next one.
        self.block = nn.ModuleList(
            [UdopBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(self.num_layers)]
        )

        self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # 2D layout (bbox) embeddings exist only in the encoder stack.
        if not self.is_decoder:
            self.cell_2d_embedding = UdopCellEmbeddings(config.max_2d_position_embeddings, config.hidden_size)

        # get weights from encoder position bias
        # (aggregated relative biases from `config.relative_bias_args`; the 1D one is tied in `_tie_weights`)
        self.relative_bias = self._get_relative_bias(config)
  962. def _tie_weights(self):
  963. for bias in self.relative_bias.biases:
  964. if isinstance(bias, RelativePositionBias1D):
  965. self._tie_or_clone_weights(
  966. bias.relative_attention_bias, self.block[0].layer[0].SelfAttention.relative_attention_bias
  967. )
  968. @staticmethod
  969. def _get_relative_bias(config: UdopConfig) -> RelativePositionBiasAggregated:
  970. relative_bias_list = create_relative_bias(config)
  971. return RelativePositionBiasAggregated(relative_bias_list)
  972. def get_output_embeddings(self):
  973. return self.embed_tokens
  974. def set_input_embeddings(self, new_embeddings):
  975. self.embed_tokens = new_embeddings
  976. def forward(
  977. self,
  978. input_ids=None,
  979. attention_mask=None,
  980. bbox=None,
  981. encoder_hidden_states=None,
  982. encoder_attention_mask=None,
  983. inputs_embeds=None,
  984. pixel_values=None,
  985. visual_bbox=None,
  986. image_embeddings=None,
  987. position_bias=None,
  988. head_mask=None,
  989. cross_attn_head_mask=None,
  990. past_key_values=None,
  991. use_cache=None,
  992. output_attentions=None,
  993. output_hidden_states=None,
  994. return_dict=None,
  995. cache_position=None,
  996. ):
  997. use_cache = use_cache if use_cache is not None else self.config.use_cache
  998. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  999. output_hidden_states = (
  1000. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1001. )
  1002. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1003. # input embeddings processing
  1004. if input_ids is not None and inputs_embeds is not None:
  1005. err_msg_prefix = "decoder_" if self.is_decoder else ""
  1006. raise ValueError(
  1007. f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
  1008. )
  1009. elif input_ids is not None and torch.numel(input_ids) > 0:
  1010. input_shape = input_ids.size()
  1011. input_ids = input_ids.view(-1, input_shape[-1])
  1012. elif inputs_embeds is None and input_ids is not None and torch.numel(input_ids) == 0:
  1013. input_ids = torch.full((4, 1024), self.config.pad_token_id, device=input_ids.device, dtype=input_ids.dtype)
  1014. attention_mask = torch.zeros((4, 1024), device=input_ids.device, dtype=input_ids.dtype)
  1015. bbox = torch.zeros((4, 1024, 4), device=input_ids.device, dtype=input_ids.dtype)
  1016. input_shape = input_ids.size()
  1017. position_bias = torch.zeros_like(self.get_extended_attention_mask(attention_mask, input_shape))
  1018. # encoder_attention_mask = attention_mask
  1019. logger.warning("Empty batch")
  1020. elif inputs_embeds is not None:
  1021. input_shape = inputs_embeds.size()[:-1]
  1022. else:
  1023. err_msg_prefix = "decoder_" if self.is_decoder else ""
  1024. raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")
  1025. if inputs_embeds is None:
  1026. if self.embed_tokens is None:
  1027. raise ValueError("You have to initialize the model with valid token embeddings")
  1028. inputs_embeds = self.embed_tokens(input_ids)
  1029. if pixel_values is not None:
  1030. image_embeddings = self.embed_patches(pixel_values)
  1031. if image_embeddings is not None:
  1032. # combine visual and OCR text embeddings
  1033. num_patches = self.config.image_size // self.config.patch_size
  1034. inputs_embeds, bbox, attention_mask = combine_image_text_embeddings(
  1035. image_embeddings,
  1036. inputs_embeds,
  1037. bbox,
  1038. visual_bbox,
  1039. attention_mask,
  1040. num_patches,
  1041. 0,
  1042. self.config.image_size,
  1043. self.config.patch_size,
  1044. )
  1045. input_shape = inputs_embeds.size()[:-1]
  1046. if not self.is_decoder and bbox is not None:
  1047. inputs_embeds += self.cell_2d_embedding(bbox)
  1048. batch_size, seq_length = input_shape
  1049. if use_cache is True:
  1050. assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
  1051. if self.is_decoder:
  1052. if use_cache and past_key_values is None:
  1053. if self.config.is_encoder_decoder:
  1054. past_key_values = EncoderDecoderCache(
  1055. DynamicCache(config=self.config), DynamicCache(config=self.config)
  1056. )
  1057. else:
  1058. past_key_values = DynamicCache(config=self.config)
  1059. elif not self.is_decoder:
  1060. # do not pass cache object down the line for encoder stack
  1061. # it messes indexing later in decoder-stack because cache object is modified in-place
  1062. past_key_values = None
  1063. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  1064. if cache_position is None:
  1065. cache_position = torch.arange(
  1066. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  1067. )
  1068. if attention_mask is None and not is_torchdynamo_compiling():
  1069. # required mask seq length can be calculated via length of past cache
  1070. mask_seq_length = past_key_values_length + seq_length
  1071. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  1072. if self.config.is_decoder:
  1073. causal_mask = self._update_causal_mask(
  1074. attention_mask,
  1075. inputs_embeds,
  1076. cache_position,
  1077. past_key_values.self_attention_cache
  1078. if isinstance(past_key_values, EncoderDecoderCache)
  1079. else past_key_values,
  1080. output_attentions,
  1081. )
  1082. else:
  1083. causal_mask = attention_mask[:, None, None, :]
  1084. causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
  1085. causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
  1086. if self.is_decoder and encoder_attention_mask is not None:
  1087. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  1088. else:
  1089. encoder_extended_attention_mask = None
  1090. # Prepare head mask if needed
  1091. head_mask = self.get_head_mask(head_mask, self.num_layers)
  1092. all_hidden_states = () if output_hidden_states else None
  1093. all_attentions = () if output_attentions else None
  1094. all_cross_attentions = () if (output_attentions and self.is_decoder) else None
  1095. if self.is_decoder: # modified lines
  1096. position_bias = None
  1097. else:
  1098. position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox)
  1099. position_bias = position_bias + causal_mask
  1100. encoder_decoder_position_bias = None
  1101. hidden_states = inputs_embeds
  1102. hidden_states = self.dropout(hidden_states)
  1103. for i, layer_module in enumerate(self.block):
  1104. if output_hidden_states:
  1105. all_hidden_states = all_hidden_states + (hidden_states,)
  1106. layer_outputs = layer_module(
  1107. hidden_states,
  1108. causal_mask,
  1109. position_bias,
  1110. encoder_hidden_states,
  1111. encoder_extended_attention_mask,
  1112. encoder_decoder_position_bias, # as a positional argument for gradient checkpointing
  1113. layer_head_mask=head_mask[i],
  1114. past_key_values=past_key_values,
  1115. use_cache=use_cache,
  1116. output_attentions=output_attentions,
  1117. cache_position=cache_position,
  1118. )
  1119. hidden_states = layer_outputs[0]
  1120. # We share the position biases between the layers - the first layer store them
  1121. # layer_outputs = hidden-states, key-value-states (self-attention weights),
  1122. # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
  1123. position_bias = layer_outputs[1]
  1124. if self.is_decoder and encoder_hidden_states is not None:
  1125. encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
  1126. if output_attentions:
  1127. all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now
  1128. if self.is_decoder:
  1129. all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
  1130. hidden_states = self.final_layer_norm(hidden_states)
  1131. hidden_states = self.dropout(hidden_states)
  1132. # Add last layer
  1133. if output_hidden_states:
  1134. all_hidden_states = all_hidden_states + (hidden_states,)
  1135. if not return_dict:
  1136. return tuple(
  1137. v
  1138. for v in [
  1139. hidden_states,
  1140. attention_mask,
  1141. past_key_values,
  1142. all_hidden_states,
  1143. all_attentions,
  1144. all_cross_attentions,
  1145. ]
  1146. if v is not None
  1147. )
  1148. return BaseModelOutputWithAttentionMask(
  1149. last_hidden_state=hidden_states,
  1150. attention_mask=attention_mask,
  1151. past_key_values=past_key_values,
  1152. hidden_states=all_hidden_states,
  1153. attentions=all_attentions,
  1154. cross_attentions=all_cross_attentions,
  1155. )
  1156. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
  1157. def _update_causal_mask(
  1158. self,
  1159. attention_mask: Union[torch.Tensor, "BlockMask"],
  1160. input_tensor: torch.Tensor,
  1161. cache_position: torch.Tensor,
  1162. past_key_values: Cache,
  1163. output_attentions: bool = False,
  1164. ):
  1165. if self.config._attn_implementation == "flash_attention_2":
  1166. if attention_mask is not None and (attention_mask == 0.0).any():
  1167. return attention_mask
  1168. return None
  1169. if self.config._attn_implementation == "flex_attention":
  1170. if isinstance(attention_mask, torch.Tensor):
  1171. attention_mask = make_flex_block_causal_mask(attention_mask)
  1172. return attention_mask
  1173. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  1174. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  1175. # to infer the attention mask.
  1176. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1177. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  1178. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  1179. if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
  1180. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  1181. attention_mask,
  1182. inputs_embeds=input_tensor,
  1183. past_key_values_length=past_seen_tokens,
  1184. is_training=self.training,
  1185. ):
  1186. return None
  1187. dtype = input_tensor.dtype
  1188. sequence_length = input_tensor.shape[1]
  1189. if using_compilable_cache:
  1190. target_length = past_key_values.get_max_cache_shape()
  1191. else:
  1192. target_length = (
  1193. attention_mask.shape[-1]
  1194. if isinstance(attention_mask, torch.Tensor)
  1195. else past_seen_tokens + sequence_length + 1
  1196. )
  1197. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  1198. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  1199. attention_mask,
  1200. sequence_length=sequence_length,
  1201. target_length=target_length,
  1202. dtype=dtype,
  1203. cache_position=cache_position,
  1204. batch_size=input_tensor.shape[0],
  1205. )
  1206. if (
  1207. self.config._attn_implementation == "sdpa"
  1208. and attention_mask is not None
  1209. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  1210. and not output_attentions
  1211. ):
  1212. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  1213. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  1214. # Details: https://github.com/pytorch/pytorch/issues/110213
  1215. min_dtype = torch.finfo(dtype).min
  1216. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  1217. return causal_mask
  1218. @staticmethod
  1219. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  1220. def _prepare_4d_causal_attention_mask_with_cache_position(
  1221. attention_mask: torch.Tensor,
  1222. sequence_length: int,
  1223. target_length: int,
  1224. dtype: torch.dtype,
  1225. cache_position: torch.Tensor,
  1226. batch_size: int,
  1227. **kwargs,
  1228. ):
  1229. """
  1230. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  1231. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  1232. Args:
  1233. attention_mask (`torch.Tensor`):
  1234. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  1235. `(batch_size, 1, query_length, key_value_length)`.
  1236. sequence_length (`int`):
  1237. The sequence length being processed.
  1238. target_length (`int`):
  1239. The target length: when generating with static cache, the mask should be as long as the static cache,
  1240. to account for the 0 padding, the part of the cache that is not filled yet.
  1241. dtype (`torch.dtype`):
  1242. The dtype to use for the 4D attention mask.
  1243. cache_position (`torch.Tensor`):
  1244. Indices depicting the position of the input sequence tokens in the sequence.
  1245. batch_size (`torch.Tensor`):
  1246. Batch size.
  1247. """
  1248. if attention_mask is not None and attention_mask.dim() == 4:
  1249. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  1250. causal_mask = attention_mask
  1251. else:
  1252. min_dtype = torch.finfo(dtype).min
  1253. causal_mask = torch.full(
  1254. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  1255. )
  1256. if sequence_length != 1:
  1257. causal_mask = torch.triu(causal_mask, diagonal=1)
  1258. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  1259. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  1260. if attention_mask is not None:
  1261. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  1262. mask_length = attention_mask.shape[-1]
  1263. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  1264. causal_mask.device
  1265. )
  1266. padding_mask = padding_mask == 0
  1267. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  1268. padding_mask, min_dtype
  1269. )
  1270. return causal_mask
  1271. @auto_docstring
  1272. class UdopModel(UdopPreTrainedModel):
  1273. _tied_weights_keys = [
  1274. "encoder.embed_tokens.weight",
  1275. "decoder.embed_tokens.weight",
  1276. "encoder.embed_patches.proj.weight",
  1277. "encoder.embed_patches.proj.bias",
  1278. "encoder.relative_bias.biases.0.relative_attention_bias.weight",
  1279. "decoder.relative_bias.biases.0.relative_attention_bias.weight",
  1280. ]
  1281. def __init__(self, config):
  1282. super().__init__(config)
  1283. # text and image embeddings
  1284. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1285. self.patch_embed = UdopPatchEmbeddings(config)
  1286. encoder_config = deepcopy(config)
  1287. encoder_config.is_decoder = False
  1288. encoder_config.use_cache = False
  1289. encoder_config.tie_encoder_decoder = False
  1290. self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed)
  1291. decoder_config = deepcopy(config)
  1292. decoder_config.is_decoder = True
  1293. decoder_config.tie_encoder_decoder = False
  1294. decoder_config.num_layers = config.num_decoder_layers
  1295. self.decoder = UdopStack(decoder_config, self.shared)
  1296. # Initialize weights and apply final processing
  1297. self.post_init()
  1298. def get_input_embeddings(self):
  1299. return self.shared
  1300. def set_input_embeddings(self, new_embeddings):
  1301. self.shared = new_embeddings
  1302. self.encoder.set_input_embeddings(new_embeddings)
  1303. self.decoder.set_input_embeddings(new_embeddings)
  1304. def get_encoder(self):
  1305. return self.encoder
  1306. @auto_docstring
  1307. def forward(
  1308. self,
  1309. input_ids: Optional[Tensor] = None,
  1310. attention_mask: Optional[Tensor] = None,
  1311. bbox: Optional[dict[str, Any]] = None,
  1312. pixel_values: Optional[Tensor] = None,
  1313. visual_bbox: Optional[dict[str, Any]] = None,
  1314. decoder_input_ids: Optional[Tensor] = None,
  1315. decoder_attention_mask: Optional[Tensor] = None,
  1316. inputs_embeds: Optional[Tensor] = None,
  1317. encoder_outputs: Optional[Tensor] = None,
  1318. past_key_values: Optional[Cache] = None,
  1319. head_mask: Optional[Tensor] = None,
  1320. decoder_inputs_embeds: Optional[Tensor] = None,
  1321. decoder_head_mask: Optional[Tensor] = None,
  1322. cross_attn_head_mask: Optional[Tensor] = None,
  1323. use_cache=True,
  1324. output_attentions: Optional[bool] = None,
  1325. output_hidden_states: Optional[bool] = None,
  1326. return_dict: Optional[bool] = None,
  1327. cache_position: Optional[torch.LongTensor] = None,
  1328. ) -> tuple[Tensor, ...]:
  1329. r"""
  1330. bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
  1331. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  1332. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  1333. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  1334. y1) represents the position of the lower right corner.
  1335. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  1336. token. See `pixel_values` for `patch_sequence_length`.
  1337. visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*):
  1338. Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model.
  1339. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1340. Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
  1341. [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  1342. [What are decoder input IDs?](../glossary#decoder-input-ids) T5 uses the `pad_token_id` as the starting
  1343. token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
  1344. `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare
  1345. `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training).
  1346. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1347. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1348. be used by default.
  1349. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1350. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1351. 1]`:
  1352. - 1 indicates the head is **not masked**,
  1353. - 0 indicates the head is **masked**.
  1354. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1355. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1356. `[0, 1]`:
  1357. - 1 indicates the head is **not masked**,
  1358. - 0 indicates the head is **masked**.
  1359. Example:
  1360. ```python
  1361. >>> from transformers import AutoProcessor, AutoModel
  1362. >>> from datasets import load_dataset
  1363. >>> import torch
  1364. >>> # load model and processor
  1365. >>> # in this case, we already have performed OCR ourselves
  1366. >>> # so we initialize the processor with `apply_ocr=False`
  1367. >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
  1368. >>> model = AutoModel.from_pretrained("microsoft/udop-large")
  1369. >>> # load an example image, along with the words and coordinates
  1370. >>> # which were extracted using an OCR engine
  1371. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  1372. >>> example = dataset[0]
  1373. >>> image = example["image"]
  1374. >>> words = example["tokens"]
  1375. >>> boxes = example["bboxes"]
  1376. >>> inputs = processor(image, words, boxes=boxes, return_tensors="pt")
  1377. >>> decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
  1378. >>> # forward pass
  1379. >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
  1380. >>> last_hidden_states = outputs.last_hidden_state
  1381. >>> list(last_hidden_states.shape)
  1382. [1, 1, 1024]
  1383. ```"""
  1384. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1385. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1386. # Encode if needed (training, first prediction pass)
  1387. if encoder_outputs is None:
  1388. encoder_outputs = self.encoder(
  1389. input_ids=input_ids,
  1390. attention_mask=attention_mask,
  1391. bbox=bbox,
  1392. pixel_values=pixel_values,
  1393. visual_bbox=visual_bbox,
  1394. inputs_embeds=inputs_embeds,
  1395. head_mask=head_mask,
  1396. output_attentions=output_attentions,
  1397. output_hidden_states=output_hidden_states,
  1398. return_dict=return_dict,
  1399. )
  1400. hidden_states = encoder_outputs[0]
  1401. encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1]
  1402. # Decode
  1403. decoder_outputs = self.decoder(
  1404. input_ids=decoder_input_ids,
  1405. attention_mask=decoder_attention_mask,
  1406. inputs_embeds=decoder_inputs_embeds,
  1407. past_key_values=past_key_values,
  1408. encoder_hidden_states=hidden_states,
  1409. encoder_attention_mask=encoder_attention_mask,
  1410. head_mask=decoder_head_mask,
  1411. cross_attn_head_mask=cross_attn_head_mask,
  1412. use_cache=use_cache,
  1413. output_attentions=output_attentions,
  1414. output_hidden_states=output_hidden_states,
  1415. return_dict=return_dict,
  1416. cache_position=cache_position,
  1417. )
  1418. if not return_dict:
  1419. # we filter out the attention mask
  1420. decoder_outputs = tuple(value for idx, value in enumerate(decoder_outputs) if idx != 1)
  1421. encoder_outputs = tuple(value for idx, value in enumerate(encoder_outputs) if idx != 1)
  1422. return decoder_outputs + encoder_outputs
  1423. return Seq2SeqModelOutput(
  1424. last_hidden_state=decoder_outputs.last_hidden_state,
  1425. past_key_values=decoder_outputs.past_key_values,
  1426. decoder_hidden_states=decoder_outputs.hidden_states,
  1427. decoder_attentions=decoder_outputs.attentions,
  1428. cross_attentions=decoder_outputs.cross_attentions,
  1429. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1430. encoder_hidden_states=encoder_outputs.hidden_states,
  1431. encoder_attentions=encoder_outputs.attentions,
  1432. )
  1433. @auto_docstring(
  1434. custom_intro="""
  1435. The UDOP encoder-decoder Transformer with a language modeling head on top, enabling to generate text given document
  1436. images and an optional prompt.
  1437. This class is based on [`T5ForConditionalGeneration`], extended to deal with images and layout (2D) data.
  1438. """
  1439. )
  1440. class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
  1441. _tied_weights_keys = [
  1442. "encoder.embed_tokens.weight",
  1443. "decoder.embed_tokens.weight",
  1444. "encoder.embed_patches.proj.weight",
  1445. "encoder.embed_patches.proj.bias",
  1446. "encoder.relative_bias.biases.0.relative_attention_bias.weight",
  1447. "decoder.relative_bias.biases.0.relative_attention_bias.weight",
  1448. "lm_head.weight",
  1449. ]
  1450. def __init__(self, config):
  1451. super().__init__(config)
  1452. # text and image embeddings
  1453. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1454. self.patch_embed = UdopPatchEmbeddings(config)
  1455. encoder_config = deepcopy(config)
  1456. encoder_config.is_decoder = False
  1457. encoder_config.use_cache = False
  1458. encoder_config.tie_encoder_decoder = False
  1459. self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed)
  1460. decoder_config = deepcopy(config)
  1461. decoder_config.is_decoder = True
  1462. decoder_config.tie_encoder_decoder = False
  1463. decoder_config.num_layers = config.num_decoder_layers
  1464. self.decoder = UdopStack(decoder_config, self.shared)
  1465. # The weights of the language modeling head are shared with those of the encoder and decoder
  1466. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  1467. # Initialize weights and apply final processing
  1468. self.post_init()
  1469. def get_input_embeddings(self):
  1470. return self.shared
  1471. def set_input_embeddings(self, new_embeddings):
  1472. self.shared = new_embeddings
  1473. self.encoder.set_input_embeddings(new_embeddings)
  1474. self.decoder.set_input_embeddings(new_embeddings)
  1475. def get_encoder(self):
  1476. return self.encoder
  1477. @auto_docstring
  1478. def forward(
  1479. self,
  1480. input_ids: Optional[Tensor] = None,
  1481. attention_mask: Optional[Tensor] = None,
  1482. bbox: Optional[dict[str, Any]] = None,
  1483. pixel_values: Optional[Tensor] = None,
  1484. visual_bbox: Optional[dict[str, Any]] = None,
  1485. decoder_input_ids: Optional[Tensor] = None,
  1486. decoder_attention_mask: Optional[Tensor] = None,
  1487. inputs_embeds: Optional[Tensor] = None,
  1488. encoder_outputs: Optional[Tensor] = None,
  1489. past_key_values: Optional[Cache] = None,
  1490. head_mask: Optional[Tensor] = None,
  1491. decoder_inputs_embeds: Optional[Tensor] = None,
  1492. decoder_head_mask: Optional[Tensor] = None,
  1493. cross_attn_head_mask: Optional[Tensor] = None,
  1494. use_cache=True,
  1495. output_attentions: Optional[bool] = None,
  1496. output_hidden_states: Optional[bool] = None,
  1497. return_dict: Optional[bool] = None,
  1498. labels: Optional[Tensor] = None,
  1499. cache_position: Optional[torch.LongTensor] = None,
  1500. ) -> tuple[Tensor, ...]:
  1501. r"""
  1502. bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
  1503. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  1504. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  1505. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  1506. y1) represents the position of the lower right corner.
  1507. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  1508. token. See `pixel_values` for `patch_sequence_length`.
  1509. visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*):
  1510. Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model.
  1511. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1512. Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
  1513. [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  1514. [What are decoder input IDs?](../glossary#decoder-input-ids) T5 uses the `pad_token_id` as the starting
  1515. token for `decoder_input_ids` generation. If `past_key_values` is used, optionally only the last
  1516. `decoder_input_ids` have to be input (see `past_key_values`). To know more on how to prepare
  1517. `decoder_input_ids` for pretraining take a look at [T5 Training](./t5#training).
  1518. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1519. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1520. be used by default.
  1521. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1522. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1523. 1]`:
  1524. - 1 indicates the head is **not masked**,
  1525. - 0 indicates the head is **masked**.
  1526. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1527. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1528. `[0, 1]`:
  1529. - 1 indicates the head is **not masked**,
  1530. - 0 indicates the head is **masked**.
  1531. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1532. Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size -
  1533. 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
  1534. config.vocab_size]`.
  1535. Examples:
  1536. ```python
  1537. >>> from transformers import AutoProcessor, UdopForConditionalGeneration
  1538. >>> from datasets import load_dataset
  1539. >>> # load model and processor
  1540. >>> # in this case, we already have performed OCR ourselves
  1541. >>> # so we initialize the processor with `apply_ocr=False`
  1542. >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
  1543. >>> model = UdopForConditionalGeneration.from_pretrained("microsoft/udop-large")
  1544. >>> # load an example image, along with the words and coordinates
  1545. >>> # which were extracted using an OCR engine
  1546. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  1547. >>> example = dataset[0]
  1548. >>> image = example["image"]
  1549. >>> words = example["tokens"]
  1550. >>> boxes = example["bboxes"]
  1551. >>> # one can use the various task prefixes (prompts) used during pre-training
  1552. >>> # e.g. the task prefix for DocVQA is "Question answering. "
  1553. >>> question = "Question answering. What is the date on the form?"
  1554. >>> encoding = processor(image, question, text_pair=words, boxes=boxes, return_tensors="pt")
  1555. >>> # autoregressive generation
  1556. >>> predicted_ids = model.generate(**encoding)
  1557. >>> print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
  1558. 9/30/92
  1559. ```"""
  1560. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1561. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1562. if decoder_input_ids is None and labels is not None:
  1563. decoder_input_ids = self._shift_right(labels)
  1564. # Encode if needed (training, first prediction pass)
  1565. if encoder_outputs is None:
  1566. encoder_outputs = self.encoder(
  1567. input_ids=input_ids,
  1568. bbox=bbox,
  1569. visual_bbox=visual_bbox,
  1570. pixel_values=pixel_values,
  1571. attention_mask=attention_mask,
  1572. inputs_embeds=inputs_embeds,
  1573. head_mask=head_mask,
  1574. output_attentions=output_attentions,
  1575. output_hidden_states=output_hidden_states,
  1576. return_dict=return_dict,
  1577. )
  1578. hidden_states = encoder_outputs[0]
  1579. encoder_attention_mask = encoder_outputs.attention_mask if return_dict else encoder_outputs[1]
  1580. # Decode
  1581. decoder_outputs = self.decoder(
  1582. input_ids=decoder_input_ids,
  1583. attention_mask=decoder_attention_mask,
  1584. inputs_embeds=decoder_inputs_embeds,
  1585. past_key_values=past_key_values,
  1586. encoder_hidden_states=hidden_states,
  1587. encoder_attention_mask=encoder_attention_mask,
  1588. head_mask=decoder_head_mask,
  1589. cross_attn_head_mask=cross_attn_head_mask,
  1590. use_cache=use_cache,
  1591. output_attentions=output_attentions,
  1592. output_hidden_states=output_hidden_states,
  1593. return_dict=return_dict,
  1594. cache_position=cache_position,
  1595. )
  1596. sequence_output = decoder_outputs[0]
  1597. if self.config.tie_word_embeddings:
  1598. # Rescale output before projecting on vocab
  1599. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  1600. sequence_output = sequence_output * (self.config.d_model**-0.5)
  1601. lm_logits = self.lm_head(sequence_output)
  1602. loss = None
  1603. if labels is not None:
  1604. loss_fct = CrossEntropyLoss(ignore_index=-100)
  1605. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  1606. if not return_dict:
  1607. output = (lm_logits,) + decoder_outputs[2:] + (encoder_outputs[0],) + encoder_outputs[2:]
  1608. return ((loss,) + output) if loss is not None else output
  1609. return Seq2SeqLMOutput(
  1610. loss=loss,
  1611. logits=lm_logits,
  1612. past_key_values=decoder_outputs.past_key_values,
  1613. decoder_hidden_states=decoder_outputs.hidden_states,
  1614. decoder_attentions=decoder_outputs.attentions,
  1615. cross_attentions=decoder_outputs.cross_attentions,
  1616. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1617. encoder_hidden_states=encoder_outputs.hidden_states,
  1618. encoder_attentions=encoder_outputs.attentions,
  1619. )
  1620. @auto_docstring
  1621. class UdopEncoderModel(UdopPreTrainedModel):
  1622. _tied_weights_keys = [
  1623. "encoder.embed_tokens.weight",
  1624. "encoder.embed_patches.proj.weight",
  1625. "encoder.embed_patches.proj.bias",
  1626. "encoder.relative_bias.biases.0.relative_attention_bias.weight",
  1627. ]
  1628. def __init__(self, config: UdopConfig):
  1629. super().__init__(config)
  1630. # text and image embeddings
  1631. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1632. self.patch_embed = UdopPatchEmbeddings(config)
  1633. encoder_config = deepcopy(config)
  1634. encoder_config.is_decoder = False
  1635. encoder_config.use_cache = False
  1636. encoder_config.is_encoder_decoder = False
  1637. self.encoder = UdopStack(encoder_config, self.shared, self.patch_embed)
  1638. # Initialize weights and apply final processing
  1639. self.post_init()
  1640. def get_input_embeddings(self):
  1641. return self.shared
  1642. def set_input_embeddings(self, new_embeddings):
  1643. self.shared = new_embeddings
  1644. self.encoder.set_input_embeddings(new_embeddings)
  1645. def get_encoder(self):
  1646. return self.encoder
  1647. def _prune_heads(self, heads_to_prune):
  1648. """
  1649. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1650. class PreTrainedModel
  1651. """
  1652. for layer, heads in heads_to_prune.items():
  1653. self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
  1654. @auto_docstring
  1655. def forward(
  1656. self,
  1657. input_ids: Optional[Tensor] = None,
  1658. bbox: Optional[dict[str, Any]] = None,
  1659. attention_mask: Optional[Tensor] = None,
  1660. pixel_values: Optional[Tensor] = None,
  1661. visual_bbox: Optional[dict[str, Any]] = None,
  1662. head_mask: Optional[Tensor] = None,
  1663. inputs_embeds: Optional[Tensor] = None,
  1664. output_attentions: Optional[bool] = None,
  1665. output_hidden_states: Optional[bool] = None,
  1666. return_dict: Optional[bool] = None,
  1667. ) -> Union[tuple[torch.FloatTensor], BaseModelOutputWithAttentionMask]:
  1668. r"""
  1669. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1670. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1671. should be able to pad the inputs on both the right and the left.
  1672. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1673. [`PreTrainedTokenizer.__call__`] for detail.
  1674. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1675. bbox (`torch.LongTensor` of shape `({0}, 4)`, *optional*):
  1676. Bounding boxes of each input sequence tokens. Selected in the range `[0,
  1677. config.max_2d_position_embeddings-1]`. Each bounding box should be a normalized version in (x0, y0, x1, y1)
  1678. format, where (x0, y0) corresponds to the position of the upper left corner in the bounding box, and (x1,
  1679. y1) represents the position of the lower right corner.
  1680. Note that `sequence_length = token_sequence_length + patch_sequence_length + 1` where `1` is for [CLS]
  1681. token. See `pixel_values` for `patch_sequence_length`.
  1682. visual_bbox (`torch.LongTensor` of shape `(batch_size, patch_sequence_length, 4)`, *optional*):
  1683. Bounding boxes of each patch in the image. If not provided, bounding boxes are created in the model.
  1684. Example:
  1685. ```python
  1686. >>> from transformers import AutoProcessor, UdopEncoderModel
  1687. >>> from huggingface_hub import hf_hub_download
  1688. >>> from datasets import load_dataset
  1689. >>> # load model and processor
  1690. >>> # in this case, we already have performed OCR ourselves
  1691. >>> # so we initialize the processor with `apply_ocr=False`
  1692. >>> processor = AutoProcessor.from_pretrained("microsoft/udop-large", apply_ocr=False)
  1693. >>> model = UdopEncoderModel.from_pretrained("microsoft/udop-large")
  1694. >>> # load an example image, along with the words and coordinates
  1695. >>> # which were extracted using an OCR engine
  1696. >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
  1697. >>> example = dataset[0]
  1698. >>> image = example["image"]
  1699. >>> words = example["tokens"]
  1700. >>> boxes = example["bboxes"]
  1701. >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
  1702. >>> outputs = model(**encoding)
  1703. >>> last_hidden_states = outputs.last_hidden_state
  1704. ```"""
  1705. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1706. output_hidden_states = (
  1707. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1708. )
  1709. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1710. encoder_outputs = self.encoder(
  1711. input_ids=input_ids,
  1712. bbox=bbox,
  1713. visual_bbox=visual_bbox,
  1714. pixel_values=pixel_values,
  1715. attention_mask=attention_mask,
  1716. inputs_embeds=inputs_embeds,
  1717. head_mask=head_mask,
  1718. output_attentions=output_attentions,
  1719. output_hidden_states=output_hidden_states,
  1720. return_dict=return_dict,
  1721. )
  1722. return encoder_outputs
  1723. __all__ = ["UdopForConditionalGeneration", "UdopPreTrainedModel", "UdopModel", "UdopEncoderModel"]