# rec_donut_swin.py (45 KB)
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/donut/modeling_donut_swin.py
"""
  18. import collections.abc
  19. from collections import OrderedDict
  20. import math
  21. import numpy as np
  22. from dataclasses import dataclass
  23. from typing import Optional, Tuple, Union
  24. import paddle
  25. from paddle import nn
  26. import paddle.nn.functional as F
  27. from paddle.nn.initializer import (
  28. TruncatedNormal,
  29. Constant,
  30. Normal,
  31. KaimingUniform,
  32. XavierUniform,
  33. )
# Module-level initializer instances; each is called on a parameter to fill it.
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
# NOTE(review): despite the name, this is a Kaiming *uniform* initializer
# (KaimingUniform) -- confirm whether a normal variant was intended.
kaiming_normal_ = KaimingUniform(nonlinearity="relu")
trunc_normal_ = TruncatedNormal(std=0.02)
xavier_uniform_ = XavierUniform()

# General docstring
_CONFIG_FOR_DOC = "DonutSwinConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "https://huggingface.co/naver-clova-ix/donut-base"
_EXPECTED_OUTPUT_SHAPE = [1, 49, 768]
  44. class DonutSwinConfig(object):
  45. model_type = "donut-swin"
  46. attribute_map = {
  47. "num_attention_heads": "num_heads",
  48. "num_hidden_layers": "num_layers",
  49. }
  50. def __init__(
  51. self,
  52. image_size=224,
  53. patch_size=4,
  54. num_channels=3,
  55. embed_dim=96,
  56. depths=[2, 2, 6, 2],
  57. num_heads=[3, 6, 12, 24],
  58. window_size=7,
  59. mlp_ratio=4.0,
  60. qkv_bias=True,
  61. hidden_dropout_prob=0.0,
  62. attention_probs_dropout_prob=0.0,
  63. drop_path_rate=0.1,
  64. hidden_act="gelu",
  65. use_absolute_embeddings=False,
  66. initializer_range=0.02,
  67. layer_norm_eps=1e-5,
  68. **kwargs,
  69. ):
  70. super().__init__()
  71. self.image_size = image_size
  72. self.patch_size = patch_size
  73. self.num_channels = num_channels
  74. self.embed_dim = embed_dim
  75. self.depths = depths
  76. self.num_layers = len(depths)
  77. self.num_heads = num_heads
  78. self.window_size = window_size
  79. self.mlp_ratio = mlp_ratio
  80. self.qkv_bias = qkv_bias
  81. self.hidden_dropout_prob = hidden_dropout_prob
  82. self.attention_probs_dropout_prob = attention_probs_dropout_prob
  83. self.drop_path_rate = drop_path_rate
  84. self.hidden_act = hidden_act
  85. self.use_absolute_embeddings = use_absolute_embeddings
  86. self.layer_norm_eps = layer_norm_eps
  87. self.initializer_range = initializer_range
  88. self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
  89. for key, value in kwargs.items():
  90. try:
  91. setattr(self, key, value)
  92. except AttributeError as err:
  93. print(f"Can't set {key} with value {value} for {self}")
  94. raise err
  95. @dataclass
  96. # Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->DonutSwin
  97. class DonutSwinEncoderOutput(OrderedDict):
  98. last_hidden_state = None
  99. hidden_states = None
  100. attentions = None
  101. reshaped_hidden_states = None
  102. def __init__(self, *args, **kwargs):
  103. super().__init__(*args, **kwargs)
  104. def __getitem__(self, k):
  105. if isinstance(k, str):
  106. inner_dict = dict(self.items())
  107. return inner_dict[k]
  108. else:
  109. return self.to_tuple()[k]
  110. def __setattr__(self, name, value):
  111. if name in self.keys() and value is not None:
  112. super().__setitem__(name, value)
  113. super().__setattr__(name, value)
  114. def __setitem__(self, key, value):
  115. super().__setitem__(key, value)
  116. super().__setattr__(key, value)
  117. def to_tuple(self):
  118. """
  119. Convert self to a tuple containing all the attributes/keys that are not `None`.
  120. """
  121. return tuple(self[k] for k in self.keys())
  122. @dataclass
  123. # Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->DonutSwin
  124. class DonutSwinModelOutput(OrderedDict):
  125. last_hidden_state = None
  126. pooler_output = None
  127. hidden_states = None
  128. attentions = None
  129. reshaped_hidden_states = None
  130. def __init__(self, *args, **kwargs):
  131. super().__init__(*args, **kwargs)
  132. def __getitem__(self, k):
  133. if isinstance(k, str):
  134. inner_dict = dict(self.items())
  135. return inner_dict[k]
  136. else:
  137. return self.to_tuple()[k]
  138. def __setattr__(self, name, value):
  139. if name in self.keys() and value is not None:
  140. super().__setitem__(name, value)
  141. super().__setattr__(name, value)
  142. def __setitem__(self, key, value):
  143. super().__setitem__(key, value)
  144. super().__setattr__(key, value)
  145. def to_tuple(self):
  146. """
  147. Convert self to a tuple containing all the attributes/keys that are not `None`.
  148. """
  149. return tuple(self[k] for k in self.keys())
  150. # Copied from transformers.models.swin.modeling_swin.window_partition
  151. def window_partition(input_feature, window_size):
  152. """
  153. Partitions the given input into windows.
  154. """
  155. batch_size, height, width, num_channels = input_feature.shape
  156. input_feature = input_feature.reshape(
  157. [
  158. batch_size,
  159. height // window_size,
  160. window_size,
  161. width // window_size,
  162. window_size,
  163. num_channels,
  164. ]
  165. )
  166. windows = input_feature.transpose([0, 1, 3, 2, 4, 5]).reshape(
  167. [-1, window_size, window_size, num_channels]
  168. )
  169. return windows
  170. # Copied from transformers.models.swin.modeling_swin.window_reverse
  171. def window_reverse(windows, window_size, height, width):
  172. """
  173. Merges windows to produce higher resolution features.
  174. """
  175. num_channels = windows.shape[-1]
  176. windows = windows.reshape(
  177. [
  178. -1,
  179. height // window_size,
  180. width // window_size,
  181. window_size,
  182. window_size,
  183. num_channels,
  184. ]
  185. )
  186. windows = windows.transpose([0, 1, 3, 2, 4, 5]).reshape(
  187. [-1, height, width, num_channels]
  188. )
  189. return windows
  190. # Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->DonutSwin
  191. class DonutSwinEmbeddings(nn.Layer):
  192. """
  193. Construct the patch and position embeddings. Optionally, also the mask token.
  194. """
  195. def __init__(self, config, use_mask_token=False):
  196. super().__init__()
  197. self.patch_embeddings = DonutSwinPatchEmbeddings(config)
  198. num_patches = self.patch_embeddings.num_patches
  199. self.patch_grid = self.patch_embeddings.grid_size
  200. if use_mask_token:
  201. self.mask_token = paddle.create_parameter(
  202. [1, 1, config.embed_dim], dtype="float32"
  203. )
  204. zeros_(self.mask_token)
  205. else:
  206. self.mask_token = None
  207. if config.use_absolute_embeddings:
  208. self.position_embeddings = paddle.create_parameter(
  209. [1, num_patches + 1, config.embed_dim], dtype="float32"
  210. )
  211. zeros_(self.position_embedding)
  212. else:
  213. self.position_embeddings = None
  214. self.norm = nn.LayerNorm(config.embed_dim)
  215. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  216. def forward(self, pixel_values, bool_masked_pos=None):
  217. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  218. embeddings = self.norm(embeddings)
  219. batch_size, seq_len, _ = embeddings.shape
  220. if bool_masked_pos is not None:
  221. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  222. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  223. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  224. if self.position_embeddings is not None:
  225. embeddings = embeddings + self.position_embeddings
  226. embeddings = self.dropout(embeddings)
  227. return embeddings, output_dimensions
  228. class MyConv2d(nn.Conv2D):
  229. def __init__(
  230. self,
  231. in_channel,
  232. out_channels,
  233. kernel_size,
  234. stride=1,
  235. padding="SAME",
  236. dilation=1,
  237. groups=1,
  238. bias_attr=False,
  239. eps=1e-6,
  240. ):
  241. super().__init__(
  242. in_channel,
  243. out_channels,
  244. kernel_size,
  245. stride=stride,
  246. padding=padding,
  247. dilation=dilation,
  248. groups=groups,
  249. bias_attr=bias_attr,
  250. )
  251. self.weight = paddle.create_parameter(
  252. [out_channels, in_channel, kernel_size[0], kernel_size[1]], dtype="float32"
  253. )
  254. self.bias = paddle.create_parameter([out_channels], dtype="float32")
  255. ones_(self.weight)
  256. zeros_(self.bias)
  257. def forward(self, x):
  258. x = F.conv2d(
  259. x,
  260. self.weight,
  261. self.bias,
  262. self._stride,
  263. self._padding,
  264. self._dilation,
  265. self._groups,
  266. )
  267. return x
  268. # Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings
  269. class DonutSwinPatchEmbeddings(nn.Layer):
  270. """
  271. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  272. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  273. Transformer.
  274. """
  275. def __init__(self, config):
  276. super().__init__()
  277. image_size, patch_size = config.image_size, config.patch_size
  278. num_channels, hidden_size = config.num_channels, config.embed_dim
  279. image_size = (
  280. image_size
  281. if isinstance(image_size, collections.abc.Iterable)
  282. else (image_size, image_size)
  283. )
  284. patch_size = (
  285. patch_size
  286. if isinstance(patch_size, collections.abc.Iterable)
  287. else (patch_size, patch_size)
  288. )
  289. num_patches = (image_size[1] // patch_size[1]) * (
  290. image_size[0] // patch_size[0]
  291. )
  292. self.image_size = image_size
  293. self.patch_size = patch_size
  294. self.num_channels = num_channels
  295. self.num_patches = num_patches
  296. self.is_export = config.is_export
  297. self.grid_size = (
  298. image_size[0] // patch_size[0],
  299. image_size[1] // patch_size[1],
  300. )
  301. self.projection = nn.Conv2D(
  302. num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
  303. )
  304. def maybe_pad(self, pixel_values, height, width):
  305. if width % self.patch_size[1] != 0:
  306. pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
  307. if self.is_export:
  308. pad_values = paddle.to_tensor(pad_values, dtype="int32")
  309. pixel_values = nn.functional.pad(pixel_values, pad_values)
  310. if height % self.patch_size[0] != 0:
  311. pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
  312. if self.is_export:
  313. pad_values = paddle.to_tensor(pad_values, dtype="int32")
  314. pixel_values = nn.functional.pad(pixel_values, pad_values)
  315. return pixel_values
  316. def forward(self, pixel_values) -> Tuple[paddle.Tensor, Tuple[int]]:
  317. _, num_channels, height, width = pixel_values.shape
  318. if num_channels != self.num_channels:
  319. raise ValueError(
  320. "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
  321. )
  322. pixel_values = self.maybe_pad(pixel_values, height, width)
  323. embeddings = self.projection(pixel_values)
  324. _, _, height, width = embeddings.shape
  325. output_dimensions = (height, width)
  326. embeddings = embeddings.flatten(2).transpose([0, 2, 1])
  327. return embeddings, output_dimensions
  328. # Copied from transformers.models.swin.modeling_swin.SwinPatchMerging
  329. class DonutSwinPatchMerging(nn.Layer):
  330. """
  331. Patch Merging Layer.
  332. Args:
  333. input_resolution (`Tuple[int]`):
  334. Resolution of input feature.
  335. dim (`int`):
  336. Number of input channels.
  337. norm_layer (`nn.Layer`, *optional*, defaults to `nn.LayerNorm`):
  338. Normalization layer class.
  339. """
  340. def __init__(
  341. self,
  342. input_resolution: Tuple[int],
  343. dim: int,
  344. norm_layer: nn.Layer = nn.LayerNorm,
  345. is_export=False,
  346. ):
  347. super().__init__()
  348. self.input_resolution = input_resolution
  349. self.dim = dim
  350. self.reduction = nn.Linear(4 * dim, 2 * dim, bias_attr=False)
  351. self.norm = norm_layer(4 * dim)
  352. self.is_export = is_export
  353. def maybe_pad(self, input_feature, height, width):
  354. should_pad = (height % 2 == 1) or (width % 2 == 1)
  355. if should_pad:
  356. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  357. if self.is_export:
  358. pad_values = paddle.to_tensor(pad_values, dtype="int32")
  359. input_feature = nn.functional.pad(input_feature, pad_values)
  360. return input_feature
  361. def forward(
  362. self, input_feature: paddle.Tensor, input_dimensions: Tuple[int, int]
  363. ) -> paddle.Tensor:
  364. height, width = input_dimensions
  365. batch_size, dim, num_channels = input_feature.shape
  366. input_feature = input_feature.reshape([batch_size, height, width, num_channels])
  367. input_feature = self.maybe_pad(input_feature, height, width)
  368. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  369. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  370. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  371. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  372. input_feature = paddle.concat(
  373. [input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1
  374. )
  375. input_feature = input_feature.reshape(
  376. [batch_size, -1, 4 * num_channels]
  377. ) # batch_size height/2*width/2 4*C
  378. input_feature = self.norm(input_feature)
  379. input_feature = self.reduction(input_feature)
  380. return input_feature
  381. # Copied from transformers.models.beit.modeling_beit.drop_path
  382. def drop_path(
  383. input: paddle.Tensor, drop_prob: float = 0.0, training: bool = False
  384. ) -> paddle.Tensor:
  385. if drop_prob == 0.0 or not training:
  386. return input
  387. keep_prob = 1 - drop_prob
  388. shape = (input.shape[0],) + (1,) * (
  389. input.ndim - 1
  390. ) # work with diff dim tensors, not just 2D ConvNets
  391. random_tensor = keep_prob + paddle.rand(
  392. shape,
  393. dtype=input.dtype,
  394. )
  395. random_tensor.floor_() # binarize
  396. output = input / keep_prob * random_tensor
  397. return output
  398. # Copied from transformers.models.swin.modeling_swin.SwinDropPath
  399. class DonutSwinDropPath(nn.Layer):
  400. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  401. def __init__(self, drop_prob: Optional[float] = None) -> None:
  402. super().__init__()
  403. self.drop_prob = drop_prob
  404. def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
  405. return drop_path(hidden_states, self.drop_prob, self.training)
  406. def extra_repr(self) -> str:
  407. return "p={}".format(self.drop_prob)
  408. class DonutSwinSelfAttention(nn.Layer):
  409. def __init__(self, config, dim, num_heads, window_size):
  410. super().__init__()
  411. if dim % num_heads != 0:
  412. raise ValueError(
  413. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  414. )
  415. self.num_attention_heads = num_heads
  416. self.attention_head_size = int(dim / num_heads)
  417. self.all_head_size = self.num_attention_heads * self.attention_head_size
  418. self.window_size = (
  419. window_size
  420. if isinstance(window_size, collections.abc.Iterable)
  421. else (window_size, window_size)
  422. )
  423. self.relative_position_bias_table = paddle.create_parameter(
  424. [(2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads],
  425. dtype="float32",
  426. )
  427. zeros_(self.relative_position_bias_table)
  428. # get pair-wise relative position index for each token inside the window
  429. coords_h = paddle.arange(self.window_size[0])
  430. coords_w = paddle.arange(self.window_size[1])
  431. coords = paddle.stack(paddle.meshgrid(coords_h, coords_w, indexing="ij"))
  432. coords_flatten = paddle.flatten(coords, 1)
  433. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  434. relative_coords = relative_coords.transpose([1, 2, 0])
  435. relative_coords[:, :, 0] += self.window_size[0] - 1
  436. relative_coords[:, :, 1] += self.window_size[1] - 1
  437. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  438. relative_position_index = relative_coords.sum(-1)
  439. self.register_buffer("relative_position_index", relative_position_index)
  440. self.query = nn.Linear(
  441. self.all_head_size, self.all_head_size, bias_attr=config.qkv_bias
  442. )
  443. self.key = nn.Linear(
  444. self.all_head_size, self.all_head_size, bias_attr=config.qkv_bias
  445. )
  446. self.value = nn.Linear(
  447. self.all_head_size, self.all_head_size, bias_attr=config.qkv_bias
  448. )
  449. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  450. def transpose_for_scores(self, x):
  451. new_x_shape = x.shape[:-1] + [
  452. self.num_attention_heads,
  453. self.attention_head_size,
  454. ]
  455. x = x.reshape(new_x_shape)
  456. return x.transpose([0, 2, 1, 3])
  457. def forward(
  458. self,
  459. hidden_states: paddle.Tensor,
  460. attention_mask=None,
  461. head_mask=None,
  462. output_attentions=False,
  463. ) -> Tuple[paddle.Tensor]:
  464. batch_size, dim, num_channels = hidden_states.shape
  465. mixed_query_layer = self.query(hidden_states)
  466. key_layer = self.transpose_for_scores(self.key(hidden_states))
  467. value_layer = self.transpose_for_scores(self.value(hidden_states))
  468. query_layer = self.transpose_for_scores(mixed_query_layer)
  469. # Take the dot product between "query" and "key" to get the raw attention scores.
  470. attention_scores = paddle.matmul(query_layer, key_layer.transpose([0, 1, 3, 2]))
  471. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  472. relative_position_bias = self.relative_position_bias_table[
  473. self.relative_position_index.reshape([-1])
  474. ]
  475. relative_position_bias = relative_position_bias.reshape(
  476. [
  477. self.window_size[0] * self.window_size[1],
  478. self.window_size[0] * self.window_size[1],
  479. -1,
  480. ]
  481. )
  482. relative_position_bias = relative_position_bias.transpose([2, 0, 1])
  483. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  484. if attention_mask is not None:
  485. # Apply the attention mask is (precomputed for all layers in DonutSwinModel forward() function)
  486. mask_shape = attention_mask.shape[0]
  487. attention_scores = attention_scores.reshape(
  488. [
  489. batch_size // mask_shape,
  490. mask_shape,
  491. self.num_attention_heads,
  492. dim,
  493. dim,
  494. ]
  495. )
  496. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(
  497. 0
  498. )
  499. attention_scores = attention_scores.reshape(
  500. [-1, self.num_attention_heads, dim, dim]
  501. )
  502. # Normalize the attention scores to probabilities.
  503. attention_probs = nn.functional.softmax(attention_scores, axis=-1)
  504. # This is actually dropping out entire tokens to attend to, which might
  505. # seem a bit unusual, but is taken from the original Transformer paper.
  506. attention_probs = self.dropout(attention_probs)
  507. # Mask heads if we want to
  508. if head_mask is not None:
  509. attention_probs = attention_probs * head_mask
  510. context_layer = paddle.matmul(attention_probs, value_layer)
  511. context_layer = context_layer.transpose([0, 2, 1, 3])
  512. new_context_layer_shape = tuple(context_layer.shape[:-2]) + (
  513. self.all_head_size,
  514. )
  515. context_layer = context_layer.reshape(new_context_layer_shape)
  516. outputs = (
  517. (context_layer, attention_probs) if output_attentions else (context_layer,)
  518. )
  519. return outputs
  520. # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput
  521. class DonutSwinSelfOutput(nn.Layer):
  522. def __init__(self, config, dim):
  523. super().__init__()
  524. self.dense = nn.Linear(dim, dim)
  525. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  526. def forward(
  527. self, hidden_states: paddle.Tensor, input_tensor: paddle.Tensor
  528. ) -> paddle.Tensor:
  529. hidden_states = self.dense(hidden_states)
  530. hidden_states = self.dropout(hidden_states)
  531. return hidden_states
  532. # Copied from transformers.models.swin.modeling_swin.SwinAttention with Swin->DonutSwin
  533. class DonutSwinAttention(nn.Layer):
  534. def __init__(self, config, dim, num_heads, window_size):
  535. super().__init__()
  536. self.self = DonutSwinSelfAttention(config, dim, num_heads, window_size)
  537. self.output = DonutSwinSelfOutput(config, dim)
  538. self.pruned_heads = set()
  539. def forward(
  540. self,
  541. hidden_states: paddle.Tensor,
  542. attention_mask=None,
  543. head_mask=None,
  544. output_attentions=False,
  545. ) -> Tuple[paddle.Tensor]:
  546. self_outputs = self.self(
  547. hidden_states, attention_mask, head_mask, output_attentions
  548. )
  549. attention_output = self.output(self_outputs[0], hidden_states)
  550. outputs = (attention_output,) + self_outputs[
  551. 1:
  552. ] # add attentions if we output them
  553. return outputs
  554. # Copied from transformers.models.swin.modeling_swin.SwinIntermediate
  555. class DonutSwinIntermediate(nn.Layer):
  556. def __init__(self, config, dim):
  557. super().__init__()
  558. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  559. self.intermediate_act_fn = F.gelu
  560. def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
  561. hidden_states = self.dense(hidden_states)
  562. hidden_states = self.intermediate_act_fn(hidden_states)
  563. return hidden_states
  564. # Copied from transformers.models.swin.modeling_swin.SwinOutput
  565. class DonutSwinOutput(nn.Layer):
  566. def __init__(self, config, dim):
  567. super().__init__()
  568. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  569. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  570. def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
  571. hidden_states = self.dense(hidden_states)
  572. hidden_states = self.dropout(hidden_states)
  573. return hidden_states
  574. # Copied from transformers.models.swin.modeling_swin.SwinLayer with Swin->DonutSwin
  575. class DonutSwinLayer(nn.Layer):
  576. def __init__(self, config, dim, input_resolution, num_heads, shift_size=0):
  577. super().__init__()
  578. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  579. self.shift_size = shift_size
  580. self.window_size = config.window_size
  581. self.input_resolution = input_resolution
  582. self.layernorm_before = nn.LayerNorm(dim, epsilon=config.layer_norm_eps)
  583. self.attention = DonutSwinAttention(
  584. config, dim, num_heads, window_size=self.window_size
  585. )
  586. self.drop_path = (
  587. DonutSwinDropPath(config.drop_path_rate)
  588. if config.drop_path_rate > 0.0
  589. else nn.Identity()
  590. )
  591. self.layernorm_after = nn.LayerNorm(dim, epsilon=config.layer_norm_eps)
  592. self.intermediate = DonutSwinIntermediate(config, dim)
  593. self.output = DonutSwinOutput(config, dim)
  594. self.is_export = config.is_export
  595. def set_shift_and_window_size(self, input_resolution):
  596. if min(input_resolution) <= self.window_size:
  597. # if window size is larger than input resolution, we don't partition windows
  598. self.shift_size = 0
  599. self.window_size = min(input_resolution)
  600. def get_attn_mask_export(self, height, width, dtype):
  601. attn_mask = None
  602. height_slices = (
  603. slice(0, -self.window_size),
  604. slice(-self.window_size, -self.shift_size),
  605. slice(-self.shift_size, None),
  606. )
  607. width_slices = (
  608. slice(0, -self.window_size),
  609. slice(-self.window_size, -self.shift_size),
  610. slice(-self.shift_size, None),
  611. )
  612. img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
  613. count = 0
  614. for height_slice in height_slices:
  615. for width_slice in width_slices:
  616. if self.shift_size > 0:
  617. img_mask[:, height_slice, width_slice, :] = count
  618. count += 1
  619. if paddle.to_tensor(self.shift_size > 0).cast(paddle.bool):
  620. # calculate attention mask for SW-MSA
  621. mask_windows = window_partition(img_mask, self.window_size)
  622. mask_windows = mask_windows.reshape(
  623. [-1, self.window_size * self.window_size]
  624. )
  625. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  626. attn_mask = attn_mask.masked_fill(
  627. attn_mask != 0, float(-100.0)
  628. ).masked_fill(attn_mask == 0, float(0.0))
  629. return attn_mask
  630. def get_attn_mask(self, height, width, dtype):
  631. if self.shift_size > 0:
  632. # calculate attention mask for SW-MSA
  633. img_mask = paddle.zeros((1, height, width, 1), dtype=dtype)
  634. height_slices = (
  635. slice(0, -self.window_size),
  636. slice(-self.window_size, -self.shift_size),
  637. slice(-self.shift_size, None),
  638. )
  639. width_slices = (
  640. slice(0, -self.window_size),
  641. slice(-self.window_size, -self.shift_size),
  642. slice(-self.shift_size, None),
  643. )
  644. count = 0
  645. for height_slice in height_slices:
  646. for width_slice in width_slices:
  647. img_mask[:, height_slice, width_slice, :] = count
  648. count += 1
  649. mask_windows = window_partition(img_mask, self.window_size)
  650. mask_windows = mask_windows.reshape(
  651. [-1, self.window_size * self.window_size]
  652. )
  653. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  654. attn_mask = attn_mask.masked_fill(
  655. attn_mask != 0, float(-100.0)
  656. ).masked_fill(attn_mask == 0, float(0.0))
  657. else:
  658. attn_mask = None
  659. return attn_mask
  660. def maybe_pad(self, hidden_states, height, width):
  661. pad_right = (self.window_size - width % self.window_size) % self.window_size
  662. pad_bottom = (self.window_size - height % self.window_size) % self.window_size
  663. pad_values = (0, 0, 0, pad_bottom, 0, pad_right, 0, 0)
  664. hidden_states = nn.functional.pad(hidden_states, pad_values)
  665. return hidden_states, pad_values
  666. def forward(
  667. self,
  668. hidden_states: paddle.Tensor,
  669. input_dimensions: Tuple[int, int],
  670. head_mask=None,
  671. output_attentions=False,
  672. always_partition=False,
  673. ) -> Tuple[paddle.Tensor, paddle.Tensor]:
  674. if not always_partition:
  675. self.set_shift_and_window_size(input_dimensions)
  676. else:
  677. pass
  678. height, width = input_dimensions
  679. batch_size, _, channels = hidden_states.shape
  680. shortcut = hidden_states
  681. hidden_states = self.layernorm_before(hidden_states)
  682. hidden_states = hidden_states.reshape([batch_size, height, width, channels])
  683. # pad hidden_states to multiples of window size
  684. hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
  685. _, height_pad, width_pad, _ = hidden_states.shape
  686. # cyclic shift
  687. if self.shift_size > 0:
  688. shift_value = (-self.shift_size, -self.shift_size)
  689. if self.is_export:
  690. shift_value = paddle.to_tensor(shift_value, dtype="int32")
  691. shifted_hidden_states = paddle.roll(
  692. hidden_states, shifts=shift_value, axis=(1, 2)
  693. )
  694. else:
  695. shifted_hidden_states = hidden_states
  696. # partition windows
  697. hidden_states_windows = window_partition(
  698. shifted_hidden_states, self.window_size
  699. )
  700. hidden_states_windows = hidden_states_windows.reshape(
  701. [-1, self.window_size * self.window_size, channels]
  702. )
  703. attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
  704. attention_outputs = self.attention(
  705. hidden_states_windows,
  706. attn_mask,
  707. head_mask,
  708. output_attentions=output_attentions,
  709. )
  710. attention_output = attention_outputs[0]
  711. attention_windows = attention_output.reshape(
  712. [-1, self.window_size, self.window_size, channels]
  713. )
  714. shifted_windows = window_reverse(
  715. attention_windows, self.window_size, height_pad, width_pad
  716. )
  717. # reverse cyclic shift
  718. if self.shift_size > 0:
  719. shift_value = (self.shift_size, self.shift_size)
  720. if self.is_export:
  721. shift_value = paddle.to_tensor(shift_value, dtype="int32")
  722. attention_windows = paddle.roll(
  723. shifted_windows, shifts=shift_value, axis=(1, 2)
  724. )
  725. else:
  726. attention_windows = shifted_windows
  727. was_padded = pad_values[3] > 0 or pad_values[5] > 0
  728. if was_padded:
  729. attention_windows = attention_windows[:, :height, :width, :].contiguous()
  730. attention_windows = attention_windows.reshape(
  731. [batch_size, height * width, channels]
  732. )
  733. hidden_states = shortcut + self.drop_path(attention_windows)
  734. layer_output = self.layernorm_after(hidden_states)
  735. layer_output = self.intermediate(layer_output)
  736. layer_output = hidden_states + self.output(layer_output)
  737. layer_outputs = (
  738. (layer_output, attention_outputs[1])
  739. if output_attentions
  740. else (layer_output,)
  741. )
  742. return layer_outputs
  743. # Copied from transformers.models.swin.modeling_swin.SwinStage with Swin->DonutSwin
  744. class DonutSwinStage(nn.Layer):
  745. def __init__(
  746. self, config, dim, input_resolution, depth, num_heads, drop_path, downsample
  747. ):
  748. super().__init__()
  749. self.config = config
  750. self.dim = dim
  751. self.blocks = nn.LayerList(
  752. [
  753. DonutSwinLayer(
  754. config=config,
  755. dim=dim,
  756. input_resolution=input_resolution,
  757. num_heads=num_heads,
  758. shift_size=0 if (i % 2 == 0) else config.window_size // 2,
  759. )
  760. for i in range(depth)
  761. ]
  762. )
  763. self.is_export = config.is_export
  764. # patch merging layer
  765. if downsample is not None:
  766. self.downsample = downsample(
  767. input_resolution,
  768. dim=dim,
  769. norm_layer=nn.LayerNorm,
  770. is_export=self.is_export,
  771. )
  772. else:
  773. self.downsample = None
  774. self.pointing = False
  775. def forward(
  776. self,
  777. hidden_states: paddle.Tensor,
  778. input_dimensions: Tuple[int, int],
  779. head_mask=None,
  780. output_attentions=False,
  781. always_partition=False,
  782. ) -> Tuple[paddle.Tensor]:
  783. height, width = input_dimensions
  784. for i, layer_module in enumerate(self.blocks):
  785. layer_head_mask = head_mask[i] if head_mask is not None else None
  786. layer_outputs = layer_module(
  787. hidden_states,
  788. input_dimensions,
  789. layer_head_mask,
  790. output_attentions,
  791. always_partition,
  792. )
  793. hidden_states = layer_outputs[0]
  794. hidden_states_before_downsampling = hidden_states
  795. if self.downsample is not None:
  796. height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
  797. output_dimensions = (height, width, height_downsampled, width_downsampled)
  798. hidden_states = self.downsample(
  799. hidden_states_before_downsampling, input_dimensions
  800. )
  801. else:
  802. output_dimensions = (height, width, height, width)
  803. stage_outputs = (
  804. hidden_states,
  805. hidden_states_before_downsampling,
  806. output_dimensions,
  807. )
  808. if output_attentions:
  809. stage_outputs += layer_outputs[1:]
  810. return stage_outputs
  811. # Copied from transformers.models.swin.modeling_swin.SwinEncoder with Swin->DonutSwin
  812. class DonutSwinEncoder(nn.Layer):
  813. def __init__(self, config, grid_size):
  814. super().__init__()
  815. self.num_layers = len(config.depths)
  816. self.config = config
  817. dpr = [
  818. x.item()
  819. for x in paddle.linspace(0, config.drop_path_rate, sum(config.depths))
  820. ]
  821. self.layers = nn.LayerList(
  822. [
  823. DonutSwinStage(
  824. config=config,
  825. dim=int(config.embed_dim * 2**i_layer),
  826. input_resolution=(
  827. grid_size[0] // (2**i_layer),
  828. grid_size[1] // (2**i_layer),
  829. ),
  830. depth=config.depths[i_layer],
  831. num_heads=config.num_heads[i_layer],
  832. drop_path=dpr[
  833. sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])
  834. ],
  835. downsample=(
  836. DonutSwinPatchMerging
  837. if (i_layer < self.num_layers - 1)
  838. else None
  839. ),
  840. )
  841. for i_layer in range(self.num_layers)
  842. ]
  843. )
  844. self.gradient_checkpointing = False
  845. def forward(
  846. self,
  847. hidden_states: paddle.Tensor,
  848. input_dimensions: Tuple[int, int],
  849. head_mask=None,
  850. output_attentions=False,
  851. output_hidden_states=False,
  852. output_hidden_states_before_downsampling=False,
  853. always_partition=False,
  854. return_dict=True,
  855. ):
  856. all_hidden_states = () if output_hidden_states else None
  857. all_reshaped_hidden_states = () if output_hidden_states else None
  858. all_self_attentions = () if output_attentions else None
  859. if output_hidden_states:
  860. batch_size, _, hidden_size = hidden_states.shape
  861. reshaped_hidden_state = hidden_states.view(
  862. batch_size, *input_dimensions, hidden_size
  863. )
  864. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  865. all_hidden_states += (hidden_states,)
  866. all_reshaped_hidden_states += (reshaped_hidden_state,)
  867. for i, layer_module in enumerate(self.layers):
  868. layer_head_mask = head_mask[i] if head_mask is not None else None
  869. if self.gradient_checkpointing and self.training:
  870. layer_outputs = self._gradient_checkpointing_func(
  871. layer_module.__call__,
  872. hidden_states,
  873. input_dimensions,
  874. layer_head_mask,
  875. output_attentions,
  876. always_partition,
  877. )
  878. else:
  879. layer_outputs = layer_module(
  880. hidden_states,
  881. input_dimensions,
  882. layer_head_mask,
  883. output_attentions,
  884. always_partition,
  885. )
  886. hidden_states = layer_outputs[0]
  887. hidden_states_before_downsampling = layer_outputs[1]
  888. output_dimensions = layer_outputs[2]
  889. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  890. if output_hidden_states and output_hidden_states_before_downsampling:
  891. batch_size, _, hidden_size = hidden_states_before_downsampling.shape
  892. reshaped_hidden_state = hidden_states_before_downsampling.reshape(
  893. [
  894. batch_size,
  895. *(output_dimensions[0], output_dimensions[1]),
  896. hidden_size,
  897. ]
  898. )
  899. reshaped_hidden_state = reshaped_hidden_state.transpose([0, 3, 1, 2])
  900. all_hidden_states += (hidden_states_before_downsampling,)
  901. all_reshaped_hidden_states += (reshaped_hidden_state,)
  902. elif output_hidden_states and not output_hidden_states_before_downsampling:
  903. batch_size, _, hidden_size = hidden_states.shape
  904. reshaped_hidden_state = hidden_states.reshape(
  905. [batch_size, *input_dimensions, hidden_size]
  906. )
  907. reshaped_hidden_state = reshaped_hidden_state.transpose([0, 3, 1, 2])
  908. all_hidden_states += (hidden_states,)
  909. all_reshaped_hidden_states += (reshaped_hidden_state,)
  910. if output_attentions:
  911. all_self_attentions += layer_outputs[3:]
  912. if not return_dict:
  913. return tuple(
  914. v
  915. for v in [hidden_states, all_hidden_states, all_self_attentions]
  916. if v is not None
  917. )
  918. return DonutSwinEncoderOutput(
  919. last_hidden_state=hidden_states,
  920. hidden_states=all_hidden_states,
  921. attentions=all_self_attentions,
  922. reshaped_hidden_states=all_reshaped_hidden_states,
  923. )
  924. class DonutSwinPreTrainedModel(nn.Layer):
  925. """
  926. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  927. models.
  928. """
  929. config_class = DonutSwinConfig
  930. base_model_prefix = "swin"
  931. main_input_name = "pixel_values"
  932. supports_gradient_checkpointing = True
  933. def _init_weights(self, module):
  934. """Initialize the weights"""
  935. if isinstance(module, (nn.Linear, nn.Conv2D)):
  936. normal_ = Normal(mean=0.0, std=self.config.initializer_range)
  937. normal_(module.weight)
  938. if module.bias is not None:
  939. zeros_(module.bias)
  940. elif isinstance(module, nn.LayerNorm):
  941. zeros_(module.bias)
  942. ones_(module.weight)
  943. def _initialize_weights(self, module):
  944. """
  945. Initialize the weights if they are not already initialized.
  946. """
  947. if getattr(module, "_is_hf_initialized", False):
  948. return
  949. self._init_weights(module)
  950. def post_init(self):
  951. self.apply(self._initialize_weights)
  952. def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
  953. if head_mask is not None:
  954. head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
  955. if is_attention_chunked is True:
  956. head_mask = head_mask.unsqueeze(-1)
  957. else:
  958. head_mask = [None] * num_hidden_layers
  959. return head_mask
  960. class DonutSwinModel(DonutSwinPreTrainedModel):
  961. def __init__(
  962. self,
  963. in_channels=3,
  964. hidden_size=1024,
  965. num_layers=4,
  966. num_heads=[4, 8, 16, 32],
  967. add_pooling_layer=True,
  968. use_mask_token=False,
  969. is_export=False,
  970. ):
  971. super().__init__()
  972. donut_swin_config = {
  973. "return_dict": True,
  974. "output_hidden_states": False,
  975. "output_attentions": False,
  976. "use_bfloat16": False,
  977. "tf_legacy_loss": False,
  978. "pruned_heads": {},
  979. "tie_word_embeddings": True,
  980. "chunk_size_feed_forward": 0,
  981. "is_encoder_decoder": False,
  982. "is_decoder": False,
  983. "cross_attention_hidden_size": None,
  984. "add_cross_attention": False,
  985. "tie_encoder_decoder": False,
  986. "max_length": 20,
  987. "min_length": 0,
  988. "do_sample": False,
  989. "early_stopping": False,
  990. "num_beams": 1,
  991. "num_beam_groups": 1,
  992. "diversity_penalty": 0.0,
  993. "temperature": 1.0,
  994. "top_k": 50,
  995. "top_p": 1.0,
  996. "typical_p": 1.0,
  997. "repetition_penalty": 1.0,
  998. "length_penalty": 1.0,
  999. "no_repeat_ngram_size": 0,
  1000. "encoder_no_repeat_ngram_size": 0,
  1001. "bad_words_ids": None,
  1002. "num_return_sequences": 1,
  1003. "output_scores": False,
  1004. "return_dict_in_generate": False,
  1005. "forced_bos_token_id": None,
  1006. "forced_eos_token_id": None,
  1007. "remove_invalid_values": False,
  1008. "exponential_decay_length_penalty": None,
  1009. "suppress_tokens": None,
  1010. "begin_suppress_tokens": None,
  1011. "architectures": None,
  1012. "finetuning_task": None,
  1013. "id2label": {0: "LABEL_0", 1: "LABEL_1"},
  1014. "label2id": {"LABEL_0": 0, "LABEL_1": 1},
  1015. "tokenizer_class": None,
  1016. "prefix": None,
  1017. "bos_token_id": None,
  1018. "pad_token_id": None,
  1019. "eos_token_id": None,
  1020. "sep_token_id": None,
  1021. "decoder_start_token_id": None,
  1022. "task_specific_params": None,
  1023. "problem_type": None,
  1024. "_name_or_path": "",
  1025. "_commit_hash": None,
  1026. "_attn_implementation_internal": None,
  1027. "transformers_version": None,
  1028. "hidden_size": hidden_size,
  1029. "num_layers": num_layers,
  1030. "path_norm": True,
  1031. "use_2d_embeddings": False,
  1032. "image_size": [420, 420],
  1033. "patch_size": 4,
  1034. "num_channels": in_channels,
  1035. "embed_dim": 128,
  1036. "depths": [2, 2, 14, 2],
  1037. "num_heads": num_heads,
  1038. "window_size": 5,
  1039. "mlp_ratio": 4.0,
  1040. "qkv_bias": True,
  1041. "hidden_dropout_prob": 0.0,
  1042. "attention_probs_dropout_prob": 0.0,
  1043. "drop_path_rate": 0.1,
  1044. "hidden_act": "gelu",
  1045. "use_absolute_embeddings": False,
  1046. "layer_norm_eps": 1e-05,
  1047. "initializer_range": 0.02,
  1048. "is_export": is_export,
  1049. }
  1050. config = DonutSwinConfig(**donut_swin_config)
  1051. self.config = config
  1052. self.num_layers = len(config.depths)
  1053. self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
  1054. self.embeddings = DonutSwinEmbeddings(config, use_mask_token=use_mask_token)
  1055. self.encoder = DonutSwinEncoder(config, self.embeddings.patch_grid)
  1056. self.pooler = nn.AdaptiveAvgPool1D(1) if add_pooling_layer else None
  1057. self.out_channels = hidden_size
  1058. self.post_init()
  1059. def get_input_embeddings(self):
  1060. return self.embeddings.patch_embeddings
  1061. def forward(
  1062. self,
  1063. input_data=None,
  1064. bool_masked_pos=None,
  1065. head_mask=None,
  1066. output_attentions=None,
  1067. output_hidden_states=None,
  1068. return_dict=None,
  1069. ) -> Union[Tuple, DonutSwinModelOutput]:
  1070. r"""
  1071. bool_masked_pos (`paddle.BoolTensor` of shape `(batch_size, num_patches)`):
  1072. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  1073. """
  1074. if self.training:
  1075. pixel_values, label, attention_mask = input_data
  1076. else:
  1077. if isinstance(input_data, list):
  1078. pixel_values = input_data[0]
  1079. else:
  1080. pixel_values = input_data
  1081. output_attentions = (
  1082. output_attentions
  1083. if output_attentions is not None
  1084. else self.config.output_attentions
  1085. )
  1086. output_hidden_states = (
  1087. output_hidden_states
  1088. if output_hidden_states is not None
  1089. else self.config.output_hidden_states
  1090. )
  1091. return_dict = (
  1092. return_dict if return_dict is not None else self.config.return_dict
  1093. )
  1094. if pixel_values is None:
  1095. raise ValueError("You have to specify pixel_values")
  1096. num_channels = pixel_values.shape[1]
  1097. if num_channels == 1:
  1098. pixel_values = paddle.repeat_interleave(pixel_values, repeats=3, axis=1)
  1099. head_mask = self.get_head_mask(head_mask, len(self.config.depths))
  1100. embedding_output, input_dimensions = self.embeddings(
  1101. pixel_values, bool_masked_pos=bool_masked_pos
  1102. )
  1103. encoder_outputs = self.encoder(
  1104. embedding_output,
  1105. input_dimensions,
  1106. head_mask=head_mask,
  1107. output_attentions=output_attentions,
  1108. output_hidden_states=output_hidden_states,
  1109. return_dict=return_dict,
  1110. )
  1111. sequence_output = encoder_outputs[0]
  1112. pooled_output = None
  1113. if self.pooler is not None:
  1114. pooled_output = self.pooler(sequence_output.transpose([0, 2, 1]))
  1115. pooled_output = paddle.flatten(pooled_output, 1)
  1116. if not return_dict:
  1117. output = (sequence_output, pooled_output) + encoder_outputs[1:]
  1118. return output
  1119. donut_swin_output = DonutSwinModelOutput(
  1120. last_hidden_state=sequence_output,
  1121. pooler_output=pooled_output,
  1122. hidden_states=encoder_outputs.hidden_states,
  1123. attentions=encoder_outputs.attentions,
  1124. reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
  1125. )
  1126. if self.training:
  1127. return donut_swin_output, label, attention_mask
  1128. else:
  1129. return donut_swin_output