rec_vary_vit.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from functools import partial
  16. from typing import Optional, Tuple, Type
  17. import numpy as np
  18. import paddle
  19. import paddle.nn as nn
  20. import paddle.nn.functional as F
  21. from paddle.nn.initializer import (
  22. Constant,
  23. KaimingUniform,
  24. Normal,
  25. TruncatedNormal,
  26. XavierUniform,
  27. )
  28. from ppocr.modeling.backbones.rec_donut_swin import DonutSwinModelOutput
  29. zeros_ = Constant(value=0.0)
  30. ones_ = Constant(value=1.0)
  31. kaiming_normal_ = KaimingUniform(nonlinearity="relu")
  32. trunc_normal_ = TruncatedNormal(std=0.02)
  33. xavier_uniform_ = XavierUniform()
  34. class MLPBlock(nn.Layer):
  35. def __init__(
  36. self,
  37. embedding_dim: int,
  38. mlp_dim: int,
  39. act: Type[nn.Layer] = nn.GELU,
  40. ) -> None:
  41. super().__init__()
  42. self.lin1 = nn.Linear(embedding_dim, mlp_dim)
  43. self.lin2 = nn.Linear(mlp_dim, embedding_dim)
  44. self.act = act()
  45. def forward(self, x):
  46. return self.lin2(self.act(self.lin1(x)))
  47. # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
  48. # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
  49. class LayerNorm2d(nn.Layer):
  50. def __init__(self, num_channels: int, epsilon: float = 1e-6) -> None:
  51. super().__init__()
  52. self.weight = paddle.create_parameter([num_channels], dtype="float32")
  53. ones_(self.weight)
  54. self.bias = paddle.create_parameter([num_channels], dtype="float32")
  55. zeros_(self.bias)
  56. self.epsilon = epsilon
  57. def forward(self, x):
  58. u = x.mean(1, keepdim=True)
  59. s = (x - u).pow(2).mean(1, keepdim=True)
  60. x = (x - u) / paddle.sqrt(s + self.epsilon)
  61. x = self.weight[:, None, None] * x + self.bias[:, None, None]
  62. return x
  63. # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
  64. class ImageEncoderViT(nn.Layer):
  65. def __init__(
  66. self,
  67. img_size: int = 1024,
  68. patch_size: int = 16,
  69. in_chans: int = 3,
  70. embed_dim: int = 768,
  71. depth: int = 12,
  72. num_heads: int = 12,
  73. mlp_ratio: float = 4.0,
  74. out_chans: int = 256,
  75. qkv_bias: bool = True,
  76. norm_layer: Type[nn.Layer] = nn.LayerNorm,
  77. act_layer: Type[nn.Layer] = nn.GELU,
  78. use_abs_pos: bool = True,
  79. use_rel_pos: bool = False,
  80. rel_pos_zero_init: bool = True,
  81. window_size: int = 0,
  82. global_attn_indexes: Tuple[int, ...] = (),
  83. is_formula: bool = False,
  84. ) -> None:
  85. """
  86. Args:
  87. img_size (int): Input image size.
  88. patch_size (int): Patch size.
  89. in_chans (int): Number of input image channels.
  90. embed_dim (int): Patch embedding dimension.
  91. depth (int): Depth of ViT.
  92. num_heads (int): Number of attention heads in each ViT block.
  93. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  94. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  95. norm_layer (nn.Layer): Normalization layer.
  96. act_layer (nn.Layer): Activation layer.
  97. use_abs_pos (bool): If True, use absolute positional embeddings.
  98. use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
  99. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  100. window_size (int): Window size for window attention blocks.
  101. global_attn_indexes (list): Indexes for blocks using global attention.
  102. """
  103. super().__init__()
  104. self.img_size = img_size
  105. self.patch_embed = PatchEmbed(
  106. kernel_size=(patch_size, patch_size),
  107. stride=(patch_size, patch_size),
  108. in_chans=in_chans,
  109. embed_dim=embed_dim,
  110. )
  111. self.pos_embed = None
  112. if use_abs_pos:
  113. # Initialize absolute positional embedding with pretrain image size.
  114. self.pos_embed = paddle.create_parameter(
  115. shape=(1, img_size // patch_size, img_size // patch_size, embed_dim),
  116. dtype="float32",
  117. )
  118. zeros_(self.pos_embed)
  119. self.blocks = nn.LayerList()
  120. for i in range(depth):
  121. block = Block(
  122. dim=embed_dim,
  123. num_heads=num_heads,
  124. mlp_ratio=mlp_ratio,
  125. qkv_bias=qkv_bias,
  126. norm_layer=norm_layer,
  127. act_layer=act_layer,
  128. use_rel_pos=use_rel_pos,
  129. rel_pos_zero_init=rel_pos_zero_init,
  130. window_size=window_size if i not in global_attn_indexes else 0,
  131. input_size=(img_size // patch_size, img_size // patch_size),
  132. )
  133. self.blocks.append(block)
  134. self.neck = nn.Sequential(
  135. nn.Conv2D(
  136. embed_dim,
  137. out_chans,
  138. kernel_size=1,
  139. bias_attr=False,
  140. ),
  141. LayerNorm2d(out_chans),
  142. nn.Conv2D(
  143. out_chans,
  144. out_chans,
  145. kernel_size=3,
  146. padding=1,
  147. bias_attr=False,
  148. ),
  149. LayerNorm2d(out_chans),
  150. )
  151. self.net_2 = nn.Conv2D(
  152. 256, 512, kernel_size=3, stride=2, padding=1, bias_attr=False
  153. )
  154. self.net_3 = nn.Conv2D(
  155. 512, 1024, kernel_size=3, stride=2, padding=1, bias_attr=False
  156. )
  157. self.is_formula = is_formula
  158. def forward(self, x):
  159. x = self.patch_embed(x)
  160. if self.pos_embed is not None:
  161. x = x + self.pos_embed
  162. for blk in self.blocks:
  163. x = blk(x)
  164. x = self.neck(x.transpose([0, 3, 1, 2]))
  165. x = self.net_2(x)
  166. if self.is_formula:
  167. x = self.net_3(x)
  168. return x
  169. class Block(nn.Layer):
  170. """Transformer blocks with support of window attention and residual propagation blocks"""
  171. def __init__(
  172. self,
  173. dim: int,
  174. num_heads: int,
  175. mlp_ratio: float = 4.0,
  176. qkv_bias: bool = True,
  177. norm_layer: Type[nn.Layer] = nn.LayerNorm,
  178. act_layer: Type[nn.Layer] = nn.GELU,
  179. use_rel_pos: bool = False,
  180. rel_pos_zero_init: bool = True,
  181. window_size: int = 0,
  182. input_size: Optional[Tuple[int, int]] = None,
  183. ) -> None:
  184. """
  185. Args:
  186. dim (int): Number of input channels.
  187. num_heads (int): Number of attention heads in each ViT block.
  188. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  189. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  190. norm_layer (nn.Layer): Normalization layer.
  191. act_layer (nn.Layer): Activation layer.
  192. use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
  193. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  194. window_size (int): Window size for window attention blocks. If it equals 0, then
  195. use global attention.
  196. input_size (tuple(int, int) or None): Input resolution for calculating the relative
  197. positional parameter size.
  198. """
  199. super().__init__()
  200. self.norm1 = norm_layer(dim)
  201. self.attn = Attention(
  202. dim,
  203. num_heads=num_heads,
  204. qkv_bias=qkv_bias,
  205. use_rel_pos=use_rel_pos,
  206. rel_pos_zero_init=rel_pos_zero_init,
  207. input_size=input_size if window_size == 0 else (window_size, window_size),
  208. )
  209. self.norm2 = norm_layer(dim)
  210. self.mlp = MLPBlock(
  211. embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
  212. )
  213. self.window_size = window_size
  214. def forward(self, x):
  215. shortcut = x
  216. x = self.norm1(x)
  217. # Window partition
  218. if self.window_size > 0:
  219. H, W = x.shape[1], x.shape[2]
  220. x, pad_hw = window_partition(x, self.window_size)
  221. x = self.attn(x)
  222. # Reverse window partition
  223. if self.window_size > 0:
  224. x = window_unpartition(x, self.window_size, pad_hw, (H, W))
  225. x = shortcut + x
  226. x = x + self.mlp(self.norm2(x))
  227. return x
  228. class Attention(nn.Layer):
  229. """Multi-head Attention block with relative position embeddings."""
  230. def __init__(
  231. self,
  232. dim: int,
  233. num_heads: int = 8,
  234. qkv_bias: bool = True,
  235. use_rel_pos: bool = False,
  236. rel_pos_zero_init: bool = True,
  237. input_size: Optional[Tuple[int, int]] = None,
  238. ) -> None:
  239. """
  240. Args:
  241. dim (int): Number of input channels.
  242. num_heads (int): Number of attention heads.
  243. qkv_bias (bool): If True, add a learnable bias to query, key, value.
  244. rel_pos (bool): If True, add relative positional embeddings to the attention map.
  245. rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
  246. input_size (tuple(int, int) or None): Input resolution for calculating the relative
  247. positional parameter size.
  248. """
  249. super().__init__()
  250. self.num_heads = num_heads
  251. head_dim = dim // num_heads
  252. self.scale = head_dim**-0.5
  253. self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
  254. self.proj = nn.Linear(dim, dim)
  255. self.use_rel_pos = use_rel_pos
  256. if self.use_rel_pos:
  257. assert (
  258. input_size is not None
  259. ), "Input size must be provided if using relative positional encoding."
  260. # initialize relative positional embeddings
  261. self.rel_pos_h = paddle.create_parameter(
  262. [2 * input_size[0] - 1, head_dim], dtype="float32"
  263. )
  264. zeros_(self.rel_pos_h)
  265. self.rel_pos_w = paddle.create_parameter(
  266. [2 * input_size[1] - 1, head_dim], dtype="float32"
  267. )
  268. zeros_(self.rel_pos_w)
  269. def forward(self, x):
  270. B, H, W, _ = x.shape
  271. qkv = (
  272. self.qkv(x)
  273. .reshape([B, H * W, 3, self.num_heads, -1])
  274. .transpose([2, 0, 3, 1, 4])
  275. )
  276. q, k, v = qkv.reshape([3, B * self.num_heads, H * W, -1]).unbind(0)
  277. attn = (q * self.scale) @ k.transpose([0, 2, 1])
  278. if self.use_rel_pos:
  279. attn = add_decomposed_rel_pos(
  280. attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
  281. )
  282. attn = F.softmax(attn, axis=-1)
  283. x = (
  284. (attn @ v)
  285. .reshape([B, self.num_heads, H, W, -1])
  286. .transpose([0, 2, 3, 1, 4])
  287. .reshape([B, H, W, -1])
  288. )
  289. x = self.proj(x)
  290. return x
  291. def window_partition(x, window_size: int):
  292. """
  293. Partition into non-overlapping windows with padding if needed.
  294. Args:
  295. x (tensor): input tokens with [B, H, W, C].
  296. window_size (int): window size.
  297. Returns:
  298. windows: windows after partition with [B * num_windows, window_size, window_size, C].
  299. (Hp, Wp): padded height and width before partition
  300. """
  301. B, H, W, C = x.shape
  302. pad_h = (window_size - H % window_size) % window_size
  303. pad_w = (window_size - W % window_size) % window_size
  304. if pad_h > 0 or pad_w > 0:
  305. x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, 0))
  306. Hp, Wp = H + pad_h, W + pad_w
  307. x = x.reshape(
  308. [B, Hp // window_size, window_size, Wp // window_size, window_size, C]
  309. )
  310. windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape([-1, window_size, window_size, C])
  311. return windows, (Hp, Wp)
  312. def window_unpartition(
  313. windows, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
  314. ):
  315. """
  316. Window unpartition into original sequences and removing padding.
  317. Args:
  318. windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
  319. window_size (int): window size.
  320. pad_hw (Tuple): padded height and width (Hp, Wp).
  321. hw (Tuple): original height and width (H, W) before padding.
  322. Returns:
  323. x: unpartitioned sequences with [B, H, W, C].
  324. """
  325. Hp, Wp = pad_hw
  326. H, W = hw
  327. B = windows.shape[0] // (Hp * Wp // window_size // window_size)
  328. x = windows.reshape(
  329. [B, Hp // window_size, Wp // window_size, window_size, window_size, -1]
  330. )
  331. x = x.transpose([0, 1, 3, 2, 4, 5]).contiguous().reshape([B, Hp, Wp, -1])
  332. if Hp > H or Wp > W:
  333. x = x[:, :H, :W, :].contiguous()
  334. return x
  335. def get_rel_pos(q_size: int, k_size: int, rel_pos):
  336. """
  337. Get relative positional embeddings according to the relative positions of
  338. query and key sizes.
  339. Args:
  340. q_size (int): size of query q.
  341. k_size (int): size of key k.
  342. rel_pos (Tensor): relative position embeddings (L, C).
  343. Returns:
  344. Extracted positional embeddings according to relative positions.
  345. """
  346. max_rel_dist = int(2 * max(q_size, k_size) - 1)
  347. # Interpolate rel pos if needed.
  348. if rel_pos.shape[0] != max_rel_dist:
  349. # Interpolate rel pos.
  350. rel_pos_resized = F.interpolate(
  351. rel_pos.reshape(1, rel_pos.shape[0], -1).transpose(0, 2, 1),
  352. size=max_rel_dist,
  353. mode="linear",
  354. )
  355. rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).transpose(1, 0)
  356. else:
  357. rel_pos_resized = rel_pos
  358. # Scale the coords with short length if shapes for q and k are different.
  359. q_coords = paddle.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
  360. k_coords = paddle.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
  361. relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
  362. return rel_pos_resized[relative_coords.cast(paddle.int64)]
  363. def add_decomposed_rel_pos(
  364. attn,
  365. q,
  366. rel_pos_h,
  367. rel_pos_w,
  368. q_size: Tuple[int, int],
  369. k_size: Tuple[int, int],
  370. ):
  371. """
  372. Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
  373. https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
  374. Args:
  375. attn (Tensor): attention map.
  376. q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
  377. rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
  378. rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
  379. q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
  380. k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
  381. Returns:
  382. attn (Tensor): attention map with added relative positional embeddings.
  383. """
  384. q_h, q_w = q_size
  385. k_h, k_w = k_size
  386. Rh = get_rel_pos(q_h, k_h, rel_pos_h)
  387. Rw = get_rel_pos(q_w, k_w, rel_pos_w)
  388. B, _, dim = q.shape
  389. r_q = q.reshape([B, q_h, q_w, dim])
  390. rel_h = paddle.einsum("bhwc,hkc->bhwk", r_q, Rh)
  391. rel_w = paddle.einsum("bhwc,wkc->bhwk", r_q, Rw)
  392. attn = (
  393. attn.reshape([B, q_h, q_w, k_h, k_w])
  394. + rel_h[:, :, :, :, None]
  395. + rel_w[:, :, :, None, :]
  396. ).reshape([B, q_h * q_w, k_h * k_w])
  397. return attn
  398. class PatchEmbed(nn.Layer):
  399. """
  400. Image to Patch Embedding.
  401. """
  402. def __init__(
  403. self,
  404. kernel_size: Tuple[int, int] = (16, 16),
  405. stride: Tuple[int, int] = (16, 16),
  406. padding: Tuple[int, int] = (0, 0),
  407. in_chans: int = 3,
  408. embed_dim: int = 768,
  409. ) -> None:
  410. """
  411. Args:
  412. kernel_size (Tuple): kernel size of the projection layer.
  413. stride (Tuple): stride of the projection layer.
  414. padding (Tuple): padding size of the projection layer.
  415. in_chans (int): Number of input image channels.
  416. embed_dim (int): Patch embedding dimension.
  417. """
  418. super().__init__()
  419. self.proj = nn.Conv2D(
  420. in_chans,
  421. embed_dim,
  422. kernel_size=kernel_size,
  423. stride=stride,
  424. padding=padding,
  425. weight_attr=True,
  426. bias_attr=True,
  427. )
  428. def forward(self, x):
  429. x = self.proj(x)
  430. # B C H W -> B H W C
  431. x = x.transpose([0, 2, 3, 1])
  432. return x
  433. def _build_vary(
  434. encoder_embed_dim,
  435. encoder_depth,
  436. encoder_num_heads,
  437. encoder_global_attn_indexes,
  438. image_size,
  439. is_formula=False,
  440. ):
  441. prompt_embed_dim = 256
  442. vit_patch_size = 16
  443. image_embedding_size = image_size // vit_patch_size
  444. image_encoder = ImageEncoderViT(
  445. depth=encoder_depth,
  446. embed_dim=encoder_embed_dim,
  447. img_size=image_size,
  448. mlp_ratio=4,
  449. norm_layer=partial(paddle.nn.LayerNorm, epsilon=1e-6),
  450. num_heads=encoder_num_heads,
  451. patch_size=vit_patch_size,
  452. qkv_bias=True,
  453. use_rel_pos=True,
  454. global_attn_indexes=encoder_global_attn_indexes,
  455. window_size=14,
  456. out_chans=prompt_embed_dim,
  457. is_formula=is_formula,
  458. )
  459. return image_encoder
  460. class Vary_VIT_B(nn.Layer):
  461. def __init__(
  462. self,
  463. in_channels=3,
  464. image_size=768,
  465. encoder_embed_dim=768,
  466. encoder_depth=12,
  467. encoder_num_heads=12,
  468. encoder_global_attn_indexes=[2, 5, 8, 11],
  469. ):
  470. super().__init__()
  471. self.vision_tower_high = _build_vary(
  472. encoder_embed_dim=768,
  473. encoder_depth=12,
  474. encoder_num_heads=12,
  475. encoder_global_attn_indexes=[2, 5, 8, 11],
  476. image_size=image_size,
  477. )
  478. self.out_channels = 1024
  479. def forward(self, input_data):
  480. pixel_values = input_data
  481. num_channels = pixel_values.shape[1]
  482. if num_channels == 1:
  483. pixel_values = paddle.repeat_interleave(pixel_values, repeats=3, axis=1)
  484. cnn_feature = self.vision_tower_high(pixel_values)
  485. cnn_feature = cnn_feature.flatten(2).transpose([0, 2, 1])
  486. return cnn_feature
  487. class Vary_VIT_B_Formula(nn.Layer):
  488. def __init__(
  489. self,
  490. in_channels=3,
  491. image_size=768,
  492. encoder_embed_dim=768,
  493. encoder_depth=12,
  494. encoder_num_heads=12,
  495. encoder_global_attn_indexes=[2, 5, 8, 11],
  496. ):
  497. """
  498. Vary_VIT_B_Formula
  499. Args:
  500. in_channels (int): Number of input channels. Default is 3 (for RGB images).
  501. image_size (int): Size of the input image. Default is 768.
  502. encoder_embed_dim (int): Dimension of the encoder's embedding. Default is 768.
  503. encoder_depth (int): Number of layers (depth) in the encoder. Default is 12.
  504. encoder_num_heads (int): Number of attention heads in the encoder. Default is 12.
  505. encoder_global_attn_indexes (list): List of indices specifying which encoder layers use global attention. Default is [2, 5, 8, 11].
  506. Returns:
  507. model: nn.Layer. Specific `Vary_VIT_B_Formula` model with defined architecture.
  508. """
  509. super(Vary_VIT_B_Formula, self).__init__()
  510. self.vision_tower_high = _build_vary(
  511. encoder_embed_dim=encoder_embed_dim,
  512. encoder_depth=encoder_depth,
  513. encoder_num_heads=encoder_num_heads,
  514. encoder_global_attn_indexes=[2, 5, 8, 11],
  515. image_size=image_size,
  516. is_formula=True,
  517. )
  518. self.mm_projector_vary = nn.Linear(1024, 1024)
  519. self.out_channels = 1024
  520. def forward(self, input_data):
  521. if self.training:
  522. pixel_values, label, attention_mask = input_data
  523. else:
  524. if isinstance(input_data, list):
  525. pixel_values = input_data[0]
  526. else:
  527. pixel_values = input_data
  528. num_channels = pixel_values.shape[1]
  529. if num_channels == 1:
  530. pixel_values = paddle.repeat_interleave(pixel_values, repeats=3, axis=1)
  531. cnn_feature = self.vision_tower_high(pixel_values)
  532. cnn_feature = cnn_feature.flatten(2).transpose([0, 2, 1])
  533. cnn_feature = self.mm_projector_vary(cnn_feature)
  534. donut_swin_output = DonutSwinModelOutput(
  535. last_hidden_state=cnn_feature,
  536. pooler_output=None,
  537. hidden_states=None,
  538. attentions=None,
  539. reshaped_hidden_states=None,
  540. )
  541. if self.training:
  542. return donut_swin_output, label, attention_mask
  543. else:
  544. return donut_swin_output