modeling_swin2sr.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104
  1. # coding=utf-8
  2. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Swin2SR Transformer model."""
  16. import collections.abc
  17. import math
  18. from dataclasses import dataclass
  19. from typing import Optional, Union
  20. import torch
  21. from torch import nn
  22. from ...activations import ACT2FN
  23. from ...modeling_layers import GradientCheckpointingLayer
  24. from ...modeling_outputs import BaseModelOutput, ImageSuperResolutionOutput
  25. from ...modeling_utils import PreTrainedModel
  26. from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
  27. from ...utils import ModelOutput, auto_docstring, logging
  28. from .configuration_swin2sr import Swin2SRConfig
  29. logger = logging.get_logger(__name__)
@dataclass
@auto_docstring(
    custom_intro="""
    Swin2SR encoder's outputs, with potential hidden states and attentions.
    """
)
class Swin2SREncoderOutput(ModelOutput):
    # Hidden states produced by the last encoder stage (the encoder passes its final `hidden_states` here).
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Presumably the embedding output followed by one entry per stage, collected by the encoder when
    # `output_hidden_states` is set -- TODO confirm against the encoder's return statement.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Presumably the attention probabilities collected by the encoder when `output_attentions` is set --
    # TODO confirm against the encoder's return statement.
    attentions: Optional[tuple[torch.FloatTensor]] = None
  40. # Copied from transformers.models.swin.modeling_swin.window_partition
  41. def window_partition(input_feature, window_size):
  42. """
  43. Partitions the given input into windows.
  44. """
  45. batch_size, height, width, num_channels = input_feature.shape
  46. input_feature = input_feature.view(
  47. batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
  48. )
  49. windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  50. return windows
  51. # Copied from transformers.models.swin.modeling_swin.window_reverse
  52. def window_reverse(windows, window_size, height, width):
  53. """
  54. Merges windows to produce higher resolution features.
  55. """
  56. num_channels = windows.shape[-1]
  57. windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
  58. windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
  59. return windows
  60. # Copied from transformers.models.beit.modeling_beit.drop_path
  61. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  62. """
  63. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  64. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  65. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  66. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  67. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  68. argument.
  69. """
  70. if drop_prob == 0.0 or not training:
  71. return input
  72. keep_prob = 1 - drop_prob
  73. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  74. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  75. random_tensor.floor_() # binarize
  76. output = input.div(keep_prob) * random_tensor
  77. return output
  78. # Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Swin2SR
  79. class Swin2SRDropPath(nn.Module):
  80. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  81. def __init__(self, drop_prob: Optional[float] = None) -> None:
  82. super().__init__()
  83. self.drop_prob = drop_prob
  84. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  85. return drop_path(hidden_states, self.drop_prob, self.training)
  86. def extra_repr(self) -> str:
  87. return f"p={self.drop_prob}"
  88. class Swin2SREmbeddings(nn.Module):
  89. """
  90. Construct the patch and optional position embeddings.
  91. """
  92. def __init__(self, config):
  93. super().__init__()
  94. self.patch_embeddings = Swin2SRPatchEmbeddings(config)
  95. num_patches = self.patch_embeddings.num_patches
  96. if config.use_absolute_embeddings:
  97. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
  98. else:
  99. self.position_embeddings = None
  100. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  101. self.window_size = config.window_size
  102. def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor]:
  103. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  104. if self.position_embeddings is not None:
  105. embeddings = embeddings + self.position_embeddings
  106. embeddings = self.dropout(embeddings)
  107. return embeddings, output_dimensions
  108. class Swin2SRPatchEmbeddings(nn.Module):
  109. def __init__(self, config, normalize_patches=True):
  110. super().__init__()
  111. num_channels = config.embed_dim
  112. image_size, patch_size = config.image_size, config.patch_size
  113. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  114. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  115. patches_resolution = [image_size[0] // patch_size[0], image_size[1] // patch_size[1]]
  116. self.patches_resolution = patches_resolution
  117. self.num_patches = patches_resolution[0] * patches_resolution[1]
  118. self.projection = nn.Conv2d(num_channels, config.embed_dim, kernel_size=patch_size, stride=patch_size)
  119. self.layernorm = nn.LayerNorm(config.embed_dim) if normalize_patches else None
  120. def forward(self, embeddings: Optional[torch.FloatTensor]) -> tuple[torch.Tensor, tuple[int]]:
  121. embeddings = self.projection(embeddings)
  122. _, _, height, width = embeddings.shape
  123. output_dimensions = (height, width)
  124. embeddings = embeddings.flatten(2).transpose(1, 2)
  125. if self.layernorm is not None:
  126. embeddings = self.layernorm(embeddings)
  127. return embeddings, output_dimensions
  128. class Swin2SRPatchUnEmbeddings(nn.Module):
  129. r"""Image to Patch Unembedding"""
  130. def __init__(self, config):
  131. super().__init__()
  132. self.embed_dim = config.embed_dim
  133. def forward(self, embeddings, x_size):
  134. batch_size, height_width, num_channels = embeddings.shape
  135. embeddings = embeddings.transpose(1, 2).view(batch_size, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
  136. return embeddings
  137. # Copied from transformers.models.swinv2.modeling_swinv2.Swinv2PatchMerging with Swinv2->Swin2SR
  138. class Swin2SRPatchMerging(nn.Module):
  139. """
  140. Patch Merging Layer.
  141. Args:
  142. input_resolution (`tuple[int]`):
  143. Resolution of input feature.
  144. dim (`int`):
  145. Number of input channels.
  146. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  147. Normalization layer class.
  148. """
  149. def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  150. super().__init__()
  151. self.input_resolution = input_resolution
  152. self.dim = dim
  153. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  154. self.norm = norm_layer(2 * dim)
  155. def maybe_pad(self, input_feature, height, width):
  156. should_pad = (height % 2 == 1) or (width % 2 == 1)
  157. if should_pad:
  158. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  159. input_feature = nn.functional.pad(input_feature, pad_values)
  160. return input_feature
  161. def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
  162. height, width = input_dimensions
  163. # `dim` is height * width
  164. batch_size, dim, num_channels = input_feature.shape
  165. input_feature = input_feature.view(batch_size, height, width, num_channels)
  166. # pad input to be divisible by width and height, if needed
  167. input_feature = self.maybe_pad(input_feature, height, width)
  168. # [batch_size, height/2, width/2, num_channels]
  169. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  170. # [batch_size, height/2, width/2, num_channels]
  171. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  172. # [batch_size, height/2, width/2, num_channels]
  173. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  174. # [batch_size, height/2, width/2, num_channels]
  175. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  176. # [batch_size, height/2 * width/2, 4*num_channels]
  177. input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
  178. input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # [batch_size, height/2 * width/2, 4*C]
  179. input_feature = self.reduction(input_feature)
  180. input_feature = self.norm(input_feature)
  181. return input_feature
  182. # Copied from transformers.models.swinv2.modeling_swinv2.Swinv2SelfAttention with Swinv2->Swin2SR
  183. class Swin2SRSelfAttention(nn.Module):
  184. def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[0, 0]):
  185. super().__init__()
  186. if dim % num_heads != 0:
  187. raise ValueError(
  188. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  189. )
  190. self.num_attention_heads = num_heads
  191. self.attention_head_size = int(dim / num_heads)
  192. self.all_head_size = self.num_attention_heads * self.attention_head_size
  193. self.window_size = (
  194. window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
  195. )
  196. self.pretrained_window_size = pretrained_window_size
  197. self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
  198. # mlp to generate continuous relative position bias
  199. self.continuous_position_bias_mlp = nn.Sequential(
  200. nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
  201. )
  202. # get relative_coords_table
  203. relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
  204. relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
  205. relative_coords_table = (
  206. torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
  207. .permute(1, 2, 0)
  208. .contiguous()
  209. .unsqueeze(0)
  210. ) # [1, 2*window_height - 1, 2*window_width - 1, 2]
  211. if pretrained_window_size[0] > 0:
  212. relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
  213. relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
  214. elif window_size > 1:
  215. relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
  216. relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
  217. relative_coords_table *= 8 # normalize to -8, 8
  218. relative_coords_table = (
  219. torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
  220. )
  221. # set to same dtype as mlp weight
  222. relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
  223. self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
  224. # get pair-wise relative position index for each token inside the window
  225. coords_h = torch.arange(self.window_size[0])
  226. coords_w = torch.arange(self.window_size[1])
  227. coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
  228. coords_flatten = torch.flatten(coords, 1)
  229. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  230. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  231. relative_coords[:, :, 0] += self.window_size[0] - 1
  232. relative_coords[:, :, 1] += self.window_size[1] - 1
  233. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  234. relative_position_index = relative_coords.sum(-1)
  235. self.register_buffer("relative_position_index", relative_position_index, persistent=False)
  236. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  237. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=False)
  238. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  239. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  240. def forward(
  241. self,
  242. hidden_states: torch.Tensor,
  243. attention_mask: Optional[torch.FloatTensor] = None,
  244. head_mask: Optional[torch.FloatTensor] = None,
  245. output_attentions: Optional[bool] = False,
  246. ) -> tuple[torch.Tensor]:
  247. batch_size, dim, num_channels = hidden_states.shape
  248. query_layer = (
  249. self.query(hidden_states)
  250. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  251. .transpose(1, 2)
  252. )
  253. key_layer = (
  254. self.key(hidden_states)
  255. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  256. .transpose(1, 2)
  257. )
  258. value_layer = (
  259. self.value(hidden_states)
  260. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  261. .transpose(1, 2)
  262. )
  263. # cosine attention
  264. attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize(
  265. key_layer, dim=-1
  266. ).transpose(-2, -1)
  267. logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
  268. attention_scores = attention_scores * logit_scale
  269. relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(
  270. -1, self.num_attention_heads
  271. )
  272. # [window_height*window_width,window_height*window_width,num_attention_heads]
  273. relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
  274. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
  275. )
  276. # [num_attention_heads,window_height*window_width,window_height*window_width]
  277. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  278. relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
  279. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  280. if attention_mask is not None:
  281. # Apply the attention mask is (precomputed for all layers in Swin2SRModel forward() function)
  282. mask_shape = attention_mask.shape[0]
  283. attention_scores = attention_scores.view(
  284. batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
  285. ) + attention_mask.unsqueeze(1).unsqueeze(0)
  286. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
  287. attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
  288. # Normalize the attention scores to probabilities.
  289. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  290. # This is actually dropping out entire tokens to attend to, which might
  291. # seem a bit unusual, but is taken from the original Transformer paper.
  292. attention_probs = self.dropout(attention_probs)
  293. # Mask heads if we want to
  294. if head_mask is not None:
  295. attention_probs = attention_probs * head_mask
  296. context_layer = torch.matmul(attention_probs, value_layer)
  297. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  298. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  299. context_layer = context_layer.view(new_context_layer_shape)
  300. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  301. return outputs
  302. # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swin2SR
  303. class Swin2SRSelfOutput(nn.Module):
  304. def __init__(self, config, dim):
  305. super().__init__()
  306. self.dense = nn.Linear(dim, dim)
  307. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  308. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  309. hidden_states = self.dense(hidden_states)
  310. hidden_states = self.dropout(hidden_states)
  311. return hidden_states
  312. # Copied from transformers.models.swinv2.modeling_swinv2.Swinv2Attention with Swinv2->Swin2SR
  313. class Swin2SRAttention(nn.Module):
  314. def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0):
  315. super().__init__()
  316. self.self = Swin2SRSelfAttention(
  317. config=config,
  318. dim=dim,
  319. num_heads=num_heads,
  320. window_size=window_size,
  321. pretrained_window_size=pretrained_window_size
  322. if isinstance(pretrained_window_size, collections.abc.Iterable)
  323. else (pretrained_window_size, pretrained_window_size),
  324. )
  325. self.output = Swin2SRSelfOutput(config, dim)
  326. self.pruned_heads = set()
  327. def prune_heads(self, heads):
  328. if len(heads) == 0:
  329. return
  330. heads, index = find_pruneable_heads_and_indices(
  331. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  332. )
  333. # Prune linear layers
  334. self.self.query = prune_linear_layer(self.self.query, index)
  335. self.self.key = prune_linear_layer(self.self.key, index)
  336. self.self.value = prune_linear_layer(self.self.value, index)
  337. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  338. # Update hyper params and store pruned heads
  339. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  340. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  341. self.pruned_heads = self.pruned_heads.union(heads)
  342. def forward(
  343. self,
  344. hidden_states: torch.Tensor,
  345. attention_mask: Optional[torch.FloatTensor] = None,
  346. head_mask: Optional[torch.FloatTensor] = None,
  347. output_attentions: Optional[bool] = False,
  348. ) -> tuple[torch.Tensor]:
  349. self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
  350. attention_output = self.output(self_outputs[0], hidden_states)
  351. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  352. return outputs
  353. # Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->Swin2SR
  354. class Swin2SRIntermediate(nn.Module):
  355. def __init__(self, config, dim):
  356. super().__init__()
  357. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  358. if isinstance(config.hidden_act, str):
  359. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  360. else:
  361. self.intermediate_act_fn = config.hidden_act
  362. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  363. hidden_states = self.dense(hidden_states)
  364. hidden_states = self.intermediate_act_fn(hidden_states)
  365. return hidden_states
  366. # Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->Swin2SR
  367. class Swin2SROutput(nn.Module):
  368. def __init__(self, config, dim):
  369. super().__init__()
  370. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  371. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  372. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  373. hidden_states = self.dense(hidden_states)
  374. hidden_states = self.dropout(hidden_states)
  375. return hidden_states
  376. # Copied from transformers.models.swinv2.modeling_swinv2.Swinv2Layer with Swinv2->Swin2SR
  377. class Swin2SRLayer(nn.Module):
  378. def __init__(
  379. self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0
  380. ):
  381. super().__init__()
  382. self.input_resolution = input_resolution
  383. window_size, shift_size = self._compute_window_shift(
  384. (config.window_size, config.window_size), (shift_size, shift_size)
  385. )
  386. self.window_size = window_size[0]
  387. self.shift_size = shift_size[0]
  388. self.attention = Swin2SRAttention(
  389. config=config,
  390. dim=dim,
  391. num_heads=num_heads,
  392. window_size=self.window_size,
  393. pretrained_window_size=pretrained_window_size
  394. if isinstance(pretrained_window_size, collections.abc.Iterable)
  395. else (pretrained_window_size, pretrained_window_size),
  396. )
  397. self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  398. self.drop_path = Swin2SRDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  399. self.intermediate = Swin2SRIntermediate(config, dim)
  400. self.output = Swin2SROutput(config, dim)
  401. self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  402. def _compute_window_shift(self, target_window_size, target_shift_size) -> tuple[tuple[int, int], tuple[int, int]]:
  403. window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
  404. shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
  405. return window_size, shift_size
  406. def get_attn_mask(self, height, width, dtype):
  407. if self.shift_size > 0:
  408. # calculate attention mask for shifted window multihead self attention
  409. img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
  410. height_slices = (
  411. slice(0, -self.window_size),
  412. slice(-self.window_size, -self.shift_size),
  413. slice(-self.shift_size, None),
  414. )
  415. width_slices = (
  416. slice(0, -self.window_size),
  417. slice(-self.window_size, -self.shift_size),
  418. slice(-self.shift_size, None),
  419. )
  420. count = 0
  421. for height_slice in height_slices:
  422. for width_slice in width_slices:
  423. img_mask[:, height_slice, width_slice, :] = count
  424. count += 1
  425. mask_windows = window_partition(img_mask, self.window_size)
  426. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  427. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  428. attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
  429. else:
  430. attn_mask = None
  431. return attn_mask
  432. def maybe_pad(self, hidden_states, height, width):
  433. pad_right = (self.window_size - width % self.window_size) % self.window_size
  434. pad_bottom = (self.window_size - height % self.window_size) % self.window_size
  435. pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
  436. hidden_states = nn.functional.pad(hidden_states, pad_values)
  437. return hidden_states, pad_values
  438. def forward(
  439. self,
  440. hidden_states: torch.Tensor,
  441. input_dimensions: tuple[int, int],
  442. head_mask: Optional[torch.FloatTensor] = None,
  443. output_attentions: Optional[bool] = False,
  444. ) -> tuple[torch.Tensor, torch.Tensor]:
  445. height, width = input_dimensions
  446. batch_size, _, channels = hidden_states.size()
  447. shortcut = hidden_states
  448. # pad hidden_states to multiples of window size
  449. hidden_states = hidden_states.view(batch_size, height, width, channels)
  450. hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
  451. _, height_pad, width_pad, _ = hidden_states.shape
  452. # cyclic shift
  453. if self.shift_size > 0:
  454. shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  455. else:
  456. shifted_hidden_states = hidden_states
  457. # partition windows
  458. hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
  459. hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
  460. attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=hidden_states.dtype)
  461. if attn_mask is not None:
  462. attn_mask = attn_mask.to(hidden_states_windows.device)
  463. attention_outputs = self.attention(
  464. hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
  465. )
  466. attention_output = attention_outputs[0]
  467. attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
  468. shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
  469. # reverse cyclic shift
  470. if self.shift_size > 0:
  471. attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  472. else:
  473. attention_windows = shifted_windows
  474. was_padded = pad_values[3] > 0 or pad_values[5] > 0
  475. if was_padded:
  476. attention_windows = attention_windows[:, :height, :width, :].contiguous()
  477. attention_windows = attention_windows.view(batch_size, height * width, channels)
  478. hidden_states = self.layernorm_before(attention_windows)
  479. hidden_states = shortcut + self.drop_path(hidden_states)
  480. layer_output = self.intermediate(hidden_states)
  481. layer_output = self.output(layer_output)
  482. layer_output = hidden_states + self.drop_path(self.layernorm_after(layer_output))
  483. layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
  484. return layer_outputs
  485. class Swin2SRStage(GradientCheckpointingLayer):
  486. """
  487. This corresponds to the Residual Swin Transformer Block (RSTB) in the original implementation.
  488. """
  489. def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, pretrained_window_size=0):
  490. super().__init__()
  491. self.config = config
  492. self.dim = dim
  493. self.layers = nn.ModuleList(
  494. [
  495. Swin2SRLayer(
  496. config=config,
  497. dim=dim,
  498. input_resolution=input_resolution,
  499. num_heads=num_heads,
  500. shift_size=0 if (i % 2 == 0) else config.window_size // 2,
  501. pretrained_window_size=pretrained_window_size,
  502. )
  503. for i in range(depth)
  504. ]
  505. )
  506. if config.resi_connection == "1conv":
  507. self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
  508. elif config.resi_connection == "3conv":
  509. # to save parameters and memory
  510. self.conv = nn.Sequential(
  511. nn.Conv2d(dim, dim // 4, 3, 1, 1),
  512. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  513. nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
  514. nn.LeakyReLU(negative_slope=0.2, inplace=True),
  515. nn.Conv2d(dim // 4, dim, 3, 1, 1),
  516. )
  517. self.patch_embed = Swin2SRPatchEmbeddings(config, normalize_patches=False)
  518. self.patch_unembed = Swin2SRPatchUnEmbeddings(config)
  519. def forward(
  520. self,
  521. hidden_states: torch.Tensor,
  522. input_dimensions: tuple[int, int],
  523. head_mask: Optional[torch.FloatTensor] = None,
  524. output_attentions: Optional[bool] = False,
  525. ) -> tuple[torch.Tensor]:
  526. residual = hidden_states
  527. height, width = input_dimensions
  528. for i, layer_module in enumerate(self.layers):
  529. layer_head_mask = head_mask[i] if head_mask is not None else None
  530. layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
  531. hidden_states = layer_outputs[0]
  532. output_dimensions = (height, width, height, width)
  533. hidden_states = self.patch_unembed(hidden_states, input_dimensions)
  534. hidden_states = self.conv(hidden_states)
  535. hidden_states, _ = self.patch_embed(hidden_states)
  536. hidden_states = hidden_states + residual
  537. stage_outputs = (hidden_states, output_dimensions)
  538. if output_attentions:
  539. stage_outputs += layer_outputs[1:]
  540. return stage_outputs
  541. class Swin2SREncoder(nn.Module):
  542. def __init__(self, config, grid_size):
  543. super().__init__()
  544. self.num_stages = len(config.depths)
  545. self.config = config
  546. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  547. self.stages = nn.ModuleList(
  548. [
  549. Swin2SRStage(
  550. config=config,
  551. dim=config.embed_dim,
  552. input_resolution=(grid_size[0], grid_size[1]),
  553. depth=config.depths[stage_idx],
  554. num_heads=config.num_heads[stage_idx],
  555. drop_path=dpr[sum(config.depths[:stage_idx]) : sum(config.depths[: stage_idx + 1])],
  556. pretrained_window_size=0,
  557. )
  558. for stage_idx in range(self.num_stages)
  559. ]
  560. )
  561. self.gradient_checkpointing = False
  562. def forward(
  563. self,
  564. hidden_states: torch.Tensor,
  565. input_dimensions: tuple[int, int],
  566. head_mask: Optional[torch.FloatTensor] = None,
  567. output_attentions: Optional[bool] = False,
  568. output_hidden_states: Optional[bool] = False,
  569. return_dict: Optional[bool] = True,
  570. ) -> Union[tuple, Swin2SREncoderOutput]:
  571. all_input_dimensions = ()
  572. all_hidden_states = () if output_hidden_states else None
  573. all_self_attentions = () if output_attentions else None
  574. if output_hidden_states:
  575. all_hidden_states += (hidden_states,)
  576. for i, stage_module in enumerate(self.stages):
  577. layer_head_mask = head_mask[i] if head_mask is not None else None
  578. layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
  579. hidden_states = layer_outputs[0]
  580. output_dimensions = layer_outputs[1]
  581. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  582. all_input_dimensions += (input_dimensions,)
  583. if output_hidden_states:
  584. all_hidden_states += (hidden_states,)
  585. if output_attentions:
  586. all_self_attentions += layer_outputs[2:]
  587. if not return_dict:
  588. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  589. return Swin2SREncoderOutput(
  590. last_hidden_state=hidden_states,
  591. hidden_states=all_hidden_states,
  592. attentions=all_self_attentions,
  593. )
  594. @auto_docstring
  595. class Swin2SRPreTrainedModel(PreTrainedModel):
  596. config: Swin2SRConfig
  597. base_model_prefix = "swin2sr"
  598. main_input_name = "pixel_values"
  599. supports_gradient_checkpointing = True
  600. def _init_weights(self, module):
  601. """Initialize the weights"""
  602. if isinstance(module, (nn.Linear, nn.Conv2d)):
  603. torch.nn.init.trunc_normal_(module.weight.data, std=self.config.initializer_range)
  604. if module.bias is not None:
  605. module.bias.data.zero_()
  606. elif isinstance(module, nn.LayerNorm):
  607. module.bias.data.zero_()
  608. module.weight.data.fill_(1.0)
  609. @auto_docstring
  610. class Swin2SRModel(Swin2SRPreTrainedModel):
  611. def __init__(self, config):
  612. super().__init__(config)
  613. self.config = config
  614. if config.num_channels == 3 and config.num_channels_out == 3:
  615. mean = torch.tensor([0.4488, 0.4371, 0.4040]).view(1, 3, 1, 1)
  616. else:
  617. mean = torch.zeros(1, 1, 1, 1)
  618. self.register_buffer("mean", mean, persistent=False)
  619. self.img_range = config.img_range
  620. self.first_convolution = nn.Conv2d(config.num_channels, config.embed_dim, 3, 1, 1)
  621. self.embeddings = Swin2SREmbeddings(config)
  622. self.encoder = Swin2SREncoder(config, grid_size=self.embeddings.patch_embeddings.patches_resolution)
  623. self.layernorm = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
  624. self.patch_unembed = Swin2SRPatchUnEmbeddings(config)
  625. self.conv_after_body = nn.Conv2d(config.embed_dim, config.embed_dim, 3, 1, 1)
  626. # Initialize weights and apply final processing
  627. self.post_init()
  628. def get_input_embeddings(self):
  629. return self.embeddings.patch_embeddings
  630. def _prune_heads(self, heads_to_prune):
  631. """
  632. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  633. class PreTrainedModel
  634. """
  635. for layer, heads in heads_to_prune.items():
  636. self.encoder.layer[layer].attention.prune_heads(heads)
  637. def pad_and_normalize(self, pixel_values):
  638. _, _, height, width = pixel_values.size()
  639. # 1. pad
  640. window_size = self.config.window_size
  641. modulo_pad_height = (window_size - height % window_size) % window_size
  642. modulo_pad_width = (window_size - width % window_size) % window_size
  643. pixel_values = nn.functional.pad(pixel_values, (0, modulo_pad_width, 0, modulo_pad_height), "reflect")
  644. # 2. normalize
  645. mean = self.mean.type_as(pixel_values)
  646. pixel_values = (pixel_values - mean) * self.img_range
  647. return pixel_values
  648. @auto_docstring
  649. def forward(
  650. self,
  651. pixel_values: torch.FloatTensor,
  652. head_mask: Optional[torch.FloatTensor] = None,
  653. output_attentions: Optional[bool] = None,
  654. output_hidden_states: Optional[bool] = None,
  655. return_dict: Optional[bool] = None,
  656. ) -> Union[tuple, BaseModelOutput]:
  657. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  658. output_hidden_states = (
  659. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  660. )
  661. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  662. # Prepare head mask if needed
  663. # 1.0 in head_mask indicate we keep the head
  664. # attention_probs has shape bsz x n_heads x N x N
  665. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  666. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  667. head_mask = self.get_head_mask(head_mask, len(self.config.depths))
  668. _, _, height, width = pixel_values.shape
  669. # some preprocessing: padding + normalization
  670. pixel_values = self.pad_and_normalize(pixel_values)
  671. embeddings = self.first_convolution(pixel_values)
  672. embedding_output, input_dimensions = self.embeddings(embeddings)
  673. encoder_outputs = self.encoder(
  674. embedding_output,
  675. input_dimensions,
  676. head_mask=head_mask,
  677. output_attentions=output_attentions,
  678. output_hidden_states=output_hidden_states,
  679. return_dict=return_dict,
  680. )
  681. sequence_output = encoder_outputs[0]
  682. sequence_output = self.layernorm(sequence_output)
  683. sequence_output = self.patch_unembed(sequence_output, (height, width))
  684. sequence_output = self.conv_after_body(sequence_output) + embeddings
  685. if not return_dict:
  686. output = (sequence_output,) + encoder_outputs[1:]
  687. return output
  688. return BaseModelOutput(
  689. last_hidden_state=sequence_output,
  690. hidden_states=encoder_outputs.hidden_states,
  691. attentions=encoder_outputs.attentions,
  692. )
  693. class Upsample(nn.Module):
  694. """Upsample module.
  695. Args:
  696. scale (`int`):
  697. Scale factor. Supported scales: 2^n and 3.
  698. num_features (`int`):
  699. Channel number of intermediate features.
  700. """
  701. def __init__(self, scale, num_features):
  702. super().__init__()
  703. self.scale = scale
  704. if (scale & (scale - 1)) == 0:
  705. # scale = 2^n
  706. for i in range(int(math.log2(scale))):
  707. self.add_module(f"convolution_{i}", nn.Conv2d(num_features, 4 * num_features, 3, 1, 1))
  708. self.add_module(f"pixelshuffle_{i}", nn.PixelShuffle(2))
  709. elif scale == 3:
  710. self.convolution = nn.Conv2d(num_features, 9 * num_features, 3, 1, 1)
  711. self.pixelshuffle = nn.PixelShuffle(3)
  712. else:
  713. raise ValueError(f"Scale {scale} is not supported. Supported scales: 2^n and 3.")
  714. def forward(self, hidden_state):
  715. if (self.scale & (self.scale - 1)) == 0:
  716. for i in range(int(math.log2(self.scale))):
  717. hidden_state = self.__getattr__(f"convolution_{i}")(hidden_state)
  718. hidden_state = self.__getattr__(f"pixelshuffle_{i}")(hidden_state)
  719. elif self.scale == 3:
  720. hidden_state = self.convolution(hidden_state)
  721. hidden_state = self.pixelshuffle(hidden_state)
  722. return hidden_state
  723. class UpsampleOneStep(nn.Module):
  724. """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
  725. Used in lightweight SR to save parameters.
  726. Args:
  727. scale (int):
  728. Scale factor. Supported scales: 2^n and 3.
  729. in_channels (int):
  730. Channel number of intermediate features.
  731. out_channels (int):
  732. Channel number of output features.
  733. """
  734. def __init__(self, scale, in_channels, out_channels):
  735. super().__init__()
  736. self.conv = nn.Conv2d(in_channels, (scale**2) * out_channels, 3, 1, 1)
  737. self.pixel_shuffle = nn.PixelShuffle(scale)
  738. def forward(self, x):
  739. x = self.conv(x)
  740. x = self.pixel_shuffle(x)
  741. return x
  742. class PixelShuffleUpsampler(nn.Module):
  743. def __init__(self, config, num_features):
  744. super().__init__()
  745. self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1)
  746. self.activation = nn.LeakyReLU(inplace=True)
  747. self.upsample = Upsample(config.upscale, num_features)
  748. self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1)
  749. def forward(self, sequence_output):
  750. x = self.conv_before_upsample(sequence_output)
  751. x = self.activation(x)
  752. x = self.upsample(x)
  753. x = self.final_convolution(x)
  754. return x
  755. class NearestConvUpsampler(nn.Module):
  756. def __init__(self, config, num_features):
  757. super().__init__()
  758. if config.upscale != 4:
  759. raise ValueError("The nearest+conv upsampler only supports an upscale factor of 4 at the moment.")
  760. self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1)
  761. self.activation = nn.LeakyReLU(inplace=True)
  762. self.conv_up1 = nn.Conv2d(num_features, num_features, 3, 1, 1)
  763. self.conv_up2 = nn.Conv2d(num_features, num_features, 3, 1, 1)
  764. self.conv_hr = nn.Conv2d(num_features, num_features, 3, 1, 1)
  765. self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1)
  766. self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
  767. def forward(self, sequence_output):
  768. sequence_output = self.conv_before_upsample(sequence_output)
  769. sequence_output = self.activation(sequence_output)
  770. sequence_output = self.lrelu(
  771. self.conv_up1(torch.nn.functional.interpolate(sequence_output, scale_factor=2, mode="nearest"))
  772. )
  773. sequence_output = self.lrelu(
  774. self.conv_up2(torch.nn.functional.interpolate(sequence_output, scale_factor=2, mode="nearest"))
  775. )
  776. reconstruction = self.final_convolution(self.lrelu(self.conv_hr(sequence_output)))
  777. return reconstruction
  778. class PixelShuffleAuxUpsampler(nn.Module):
  779. def __init__(self, config, num_features):
  780. super().__init__()
  781. self.upscale = config.upscale
  782. self.conv_bicubic = nn.Conv2d(config.num_channels, num_features, 3, 1, 1)
  783. self.conv_before_upsample = nn.Conv2d(config.embed_dim, num_features, 3, 1, 1)
  784. self.activation = nn.LeakyReLU(inplace=True)
  785. self.conv_aux = nn.Conv2d(num_features, config.num_channels, 3, 1, 1)
  786. self.conv_after_aux = nn.Sequential(nn.Conv2d(3, num_features, 3, 1, 1), nn.LeakyReLU(inplace=True))
  787. self.upsample = Upsample(config.upscale, num_features)
  788. self.final_convolution = nn.Conv2d(num_features, config.num_channels_out, 3, 1, 1)
  789. def forward(self, sequence_output, bicubic, height, width):
  790. bicubic = self.conv_bicubic(bicubic)
  791. sequence_output = self.conv_before_upsample(sequence_output)
  792. sequence_output = self.activation(sequence_output)
  793. aux = self.conv_aux(sequence_output)
  794. sequence_output = self.conv_after_aux(aux)
  795. sequence_output = (
  796. self.upsample(sequence_output)[:, :, : height * self.upscale, : width * self.upscale]
  797. + bicubic[:, :, : height * self.upscale, : width * self.upscale]
  798. )
  799. reconstruction = self.final_convolution(sequence_output)
  800. return reconstruction, aux
  801. @auto_docstring(
  802. custom_intro="""
  803. Swin2SR Model transformer with an upsampler head on top for image super resolution and restoration.
  804. """
  805. )
  806. class Swin2SRForImageSuperResolution(Swin2SRPreTrainedModel):
  807. def __init__(self, config):
  808. super().__init__(config)
  809. self.swin2sr = Swin2SRModel(config)
  810. self.upsampler = config.upsampler
  811. self.upscale = config.upscale
  812. # Upsampler
  813. num_features = 64
  814. if self.upsampler == "pixelshuffle":
  815. self.upsample = PixelShuffleUpsampler(config, num_features)
  816. elif self.upsampler == "pixelshuffle_aux":
  817. self.upsample = PixelShuffleAuxUpsampler(config, num_features)
  818. elif self.upsampler == "pixelshuffledirect":
  819. # for lightweight SR (to save parameters)
  820. self.upsample = UpsampleOneStep(config.upscale, config.embed_dim, config.num_channels_out)
  821. elif self.upsampler == "nearest+conv":
  822. # for real-world SR (less artifacts)
  823. self.upsample = NearestConvUpsampler(config, num_features)
  824. else:
  825. # for image denoising and JPEG compression artifact reduction
  826. self.final_convolution = nn.Conv2d(config.embed_dim, config.num_channels_out, 3, 1, 1)
  827. # Initialize weights and apply final processing
  828. self.post_init()
  829. @auto_docstring
  830. def forward(
  831. self,
  832. pixel_values: Optional[torch.FloatTensor] = None,
  833. head_mask: Optional[torch.FloatTensor] = None,
  834. labels: Optional[torch.LongTensor] = None,
  835. output_attentions: Optional[bool] = None,
  836. output_hidden_states: Optional[bool] = None,
  837. return_dict: Optional[bool] = None,
  838. ) -> Union[tuple, ImageSuperResolutionOutput]:
  839. r"""
  840. Example:
  841. ```python
  842. >>> import torch
  843. >>> import numpy as np
  844. >>> from PIL import Image
  845. >>> import requests
  846. >>> from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
  847. >>> processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
  848. >>> model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
  849. >>> url = "https://huggingface.co/spaces/jjourney1125/swin2sr/resolve/main/samples/butterfly.jpg"
  850. >>> image = Image.open(requests.get(url, stream=True).raw)
  851. >>> # prepare image for the model
  852. >>> inputs = processor(image, return_tensors="pt")
  853. >>> # forward pass
  854. >>> with torch.no_grad():
  855. ... outputs = model(**inputs)
  856. >>> output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
  857. >>> output = np.moveaxis(output, source=0, destination=-1)
  858. >>> output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
  859. >>> # you can visualize `output` with `Image.fromarray`
  860. ```"""
  861. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  862. loss = None
  863. if labels is not None:
  864. raise NotImplementedError("Training is not supported at the moment")
  865. height, width = pixel_values.shape[2:]
  866. if self.config.upsampler == "pixelshuffle_aux":
  867. bicubic = nn.functional.interpolate(
  868. pixel_values,
  869. size=(height * self.upscale, width * self.upscale),
  870. mode="bicubic",
  871. align_corners=False,
  872. )
  873. outputs = self.swin2sr(
  874. pixel_values,
  875. head_mask=head_mask,
  876. output_attentions=output_attentions,
  877. output_hidden_states=output_hidden_states,
  878. return_dict=return_dict,
  879. )
  880. sequence_output = outputs[0]
  881. if self.upsampler in ["pixelshuffle", "pixelshuffledirect", "nearest+conv"]:
  882. reconstruction = self.upsample(sequence_output)
  883. elif self.upsampler == "pixelshuffle_aux":
  884. reconstruction, aux = self.upsample(sequence_output, bicubic, height, width)
  885. aux = aux / self.swin2sr.img_range + self.swin2sr.mean
  886. else:
  887. reconstruction = pixel_values + self.final_convolution(sequence_output)
  888. reconstruction = reconstruction / self.swin2sr.img_range + self.swin2sr.mean
  889. reconstruction = reconstruction[:, :, : height * self.upscale, : width * self.upscale]
  890. if not return_dict:
  891. output = (reconstruction,) + outputs[1:]
  892. return ((loss,) + output) if loss is not None else output
  893. return ImageSuperResolutionOutput(
  894. loss=loss,
  895. reconstruction=reconstruction,
  896. hidden_states=outputs.hidden_states,
  897. attentions=outputs.attentions,
  898. )
# Names exported by `from <this module> import *`.
__all__ = ["Swin2SRForImageSuperResolution", "Swin2SRModel", "Swin2SRPreTrainedModel"]