modeling_swin.py 53 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278
  1. # coding=utf-8
  2. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Swin Transformer model."""
  16. import collections.abc
  17. import math
  18. import warnings
  19. from dataclasses import dataclass
  20. from typing import Optional, Union
  21. import torch
  22. from torch import nn
  23. from ...activations import ACT2FN
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BackboneOutput
  26. from ...modeling_utils import PreTrainedModel
  27. from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
  28. from ...utils import ModelOutput, auto_docstring, logging, torch_int
  29. from ...utils.backbone_utils import BackboneMixin
  30. from .configuration_swin import SwinConfig
# Module-level logger obtained from the transformers `logging` utilities imported above.
logger = logging.get_logger(__name__)

# drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library.
@dataclass
@auto_docstring(
    custom_intro="""
Swin encoder's outputs, with potential hidden states and attentions.
"""
)
class SwinEncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # All fields default to None so that optional outputs can simply be omitted by the producer.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Same information as `hidden_states`, but with the spatial (height, width) dimensions restored.
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Swin model's outputs that also contains a pooling of the last hidden states.
"""
)
class SwinModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # All fields default to None so that optional outputs can simply be omitted by the producer.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Only populated when the model was built with a pooling layer (see docstring).
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Swin masked image model outputs.
"""
)
class SwinMaskedImageModelingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
        Masked image modeling (MIM) loss.
    reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Reconstructed pixel values.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # All fields default to None so that optional outputs can simply be omitted by the producer.
    loss: Optional[torch.FloatTensor] = None
    reconstruction: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None

    @property
    def logits(self):
        """Deprecated alias of `reconstruction`; emits a `FutureWarning` on every access."""
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        return self.reconstruction
@dataclass
@auto_docstring(
    custom_intro="""
Swin outputs for image classification.
"""
)
class SwinImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # All fields default to None so that optional outputs can simply be omitted by the producer.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  126. def window_partition(input_feature, window_size):
  127. """
  128. Partitions the given input into windows.
  129. """
  130. batch_size, height, width, num_channels = input_feature.shape
  131. input_feature = input_feature.view(
  132. batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
  133. )
  134. windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  135. return windows
  136. def window_reverse(windows, window_size, height, width):
  137. """
  138. Merges windows to produce higher resolution features.
  139. """
  140. num_channels = windows.shape[-1]
  141. windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
  142. windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
  143. return windows
  144. class SwinEmbeddings(nn.Module):
  145. """
  146. Construct the patch and position embeddings. Optionally, also the mask token.
  147. """
  148. def __init__(self, config, use_mask_token=False):
  149. super().__init__()
  150. self.patch_embeddings = SwinPatchEmbeddings(config)
  151. num_patches = self.patch_embeddings.num_patches
  152. self.patch_grid = self.patch_embeddings.grid_size
  153. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
  154. if config.use_absolute_embeddings:
  155. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
  156. else:
  157. self.position_embeddings = None
  158. self.norm = nn.LayerNorm(config.embed_dim)
  159. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  160. self.patch_size = config.patch_size
  161. self.config = config
  162. # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
  163. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  164. """
  165. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  166. images. This method is also adapted to support torch.jit tracing.
  167. Adapted from:
  168. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  169. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  170. """
  171. num_patches = embeddings.shape[1] - 1
  172. num_positions = self.position_embeddings.shape[1] - 1
  173. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  174. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  175. return self.position_embeddings
  176. class_pos_embed = self.position_embeddings[:, :1]
  177. patch_pos_embed = self.position_embeddings[:, 1:]
  178. dim = embeddings.shape[-1]
  179. new_height = height // self.patch_size
  180. new_width = width // self.patch_size
  181. sqrt_num_positions = torch_int(num_positions**0.5)
  182. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  183. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  184. patch_pos_embed = nn.functional.interpolate(
  185. patch_pos_embed,
  186. size=(new_height, new_width),
  187. mode="bicubic",
  188. align_corners=False,
  189. )
  190. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  191. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  192. def forward(
  193. self,
  194. pixel_values: Optional[torch.FloatTensor],
  195. bool_masked_pos: Optional[torch.BoolTensor] = None,
  196. interpolate_pos_encoding: bool = False,
  197. ) -> tuple[torch.Tensor]:
  198. _, num_channels, height, width = pixel_values.shape
  199. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  200. embeddings = self.norm(embeddings)
  201. batch_size, seq_len, _ = embeddings.size()
  202. if bool_masked_pos is not None:
  203. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  204. # replace the masked visual tokens by mask_tokens
  205. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  206. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  207. if self.position_embeddings is not None:
  208. if interpolate_pos_encoding:
  209. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  210. else:
  211. embeddings = embeddings + self.position_embeddings
  212. embeddings = self.dropout(embeddings)
  213. return embeddings, output_dimensions
  214. class SwinPatchEmbeddings(nn.Module):
  215. """
  216. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  217. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  218. Transformer.
  219. """
  220. def __init__(self, config):
  221. super().__init__()
  222. image_size, patch_size = config.image_size, config.patch_size
  223. num_channels, hidden_size = config.num_channels, config.embed_dim
  224. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  225. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  226. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  227. self.image_size = image_size
  228. self.patch_size = patch_size
  229. self.num_channels = num_channels
  230. self.num_patches = num_patches
  231. self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  232. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  233. def maybe_pad(self, pixel_values, height, width):
  234. if width % self.patch_size[1] != 0:
  235. pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
  236. pixel_values = nn.functional.pad(pixel_values, pad_values)
  237. if height % self.patch_size[0] != 0:
  238. pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
  239. pixel_values = nn.functional.pad(pixel_values, pad_values)
  240. return pixel_values
  241. def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor, tuple[int]]:
  242. _, num_channels, height, width = pixel_values.shape
  243. # pad the input to be divisible by self.patch_size, if needed
  244. pixel_values = self.maybe_pad(pixel_values, height, width)
  245. embeddings = self.projection(pixel_values)
  246. _, _, height, width = embeddings.shape
  247. output_dimensions = (height, width)
  248. embeddings = embeddings.flatten(2).transpose(1, 2)
  249. return embeddings, output_dimensions
  250. class SwinPatchMerging(nn.Module):
  251. """
  252. Patch Merging Layer.
  253. Args:
  254. input_resolution (`tuple[int]`):
  255. Resolution of input feature.
  256. dim (`int`):
  257. Number of input channels.
  258. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  259. Normalization layer class.
  260. """
  261. def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  262. super().__init__()
  263. self.input_resolution = input_resolution
  264. self.dim = dim
  265. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  266. self.norm = norm_layer(4 * dim)
  267. def maybe_pad(self, input_feature, height, width):
  268. should_pad = (height % 2 == 1) or (width % 2 == 1)
  269. if should_pad:
  270. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  271. input_feature = nn.functional.pad(input_feature, pad_values)
  272. return input_feature
  273. def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
  274. height, width = input_dimensions
  275. # `dim` is height * width
  276. batch_size, dim, num_channels = input_feature.shape
  277. input_feature = input_feature.view(batch_size, height, width, num_channels)
  278. # pad input to be divisible by width and height, if needed
  279. input_feature = self.maybe_pad(input_feature, height, width)
  280. # [batch_size, height/2, width/2, num_channels]
  281. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  282. # [batch_size, height/2, width/2, num_channels]
  283. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  284. # [batch_size, height/2, width/2, num_channels]
  285. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  286. # [batch_size, height/2, width/2, num_channels]
  287. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  288. # batch_size height/2 width/2 4*num_channels
  289. input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
  290. input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C
  291. input_feature = self.norm(input_feature)
  292. input_feature = self.reduction(input_feature)
  293. return input_feature
  294. # Copied from transformers.models.beit.modeling_beit.drop_path
  295. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  296. """
  297. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  298. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  299. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  300. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  301. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  302. argument.
  303. """
  304. if drop_prob == 0.0 or not training:
  305. return input
  306. keep_prob = 1 - drop_prob
  307. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  308. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  309. random_tensor.floor_() # binarize
  310. output = input.div(keep_prob) * random_tensor
  311. return output
  312. # Copied from transformers.models.beit.modeling_beit.BeitDropPath with Beit->Swin
  313. class SwinDropPath(nn.Module):
  314. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  315. def __init__(self, drop_prob: Optional[float] = None) -> None:
  316. super().__init__()
  317. self.drop_prob = drop_prob
  318. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  319. return drop_path(hidden_states, self.drop_prob, self.training)
  320. def extra_repr(self) -> str:
  321. return f"p={self.drop_prob}"
  322. class SwinSelfAttention(nn.Module):
  323. def __init__(self, config, dim, num_heads, window_size):
  324. super().__init__()
  325. if dim % num_heads != 0:
  326. raise ValueError(
  327. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  328. )
  329. self.num_attention_heads = num_heads
  330. self.attention_head_size = int(dim / num_heads)
  331. self.all_head_size = self.num_attention_heads * self.attention_head_size
  332. self.window_size = (
  333. window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
  334. )
  335. self.relative_position_bias_table = nn.Parameter(
  336. torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
  337. )
  338. # get pair-wise relative position index for each token inside the window
  339. coords_h = torch.arange(self.window_size[0])
  340. coords_w = torch.arange(self.window_size[1])
  341. coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
  342. coords_flatten = torch.flatten(coords, 1)
  343. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  344. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  345. relative_coords[:, :, 0] += self.window_size[0] - 1
  346. relative_coords[:, :, 1] += self.window_size[1] - 1
  347. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  348. relative_position_index = relative_coords.sum(-1)
  349. self.register_buffer("relative_position_index", relative_position_index)
  350. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  351. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  352. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  353. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  354. def forward(
  355. self,
  356. hidden_states: torch.Tensor,
  357. attention_mask: Optional[torch.FloatTensor] = None,
  358. head_mask: Optional[torch.FloatTensor] = None,
  359. output_attentions: Optional[bool] = False,
  360. ) -> tuple[torch.Tensor]:
  361. batch_size, dim, num_channels = hidden_states.shape
  362. hidden_shape = (batch_size, dim, -1, self.attention_head_size)
  363. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  364. key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  365. value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  366. # Take the dot product between "query" and "key" to get the raw attention scores.
  367. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  368. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  369. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
  370. relative_position_bias = relative_position_bias.view(
  371. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
  372. )
  373. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
  374. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  375. if attention_mask is not None:
  376. # Apply the attention mask is (precomputed for all layers in SwinModel forward() function)
  377. mask_shape = attention_mask.shape[0]
  378. attention_scores = attention_scores.view(
  379. batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
  380. )
  381. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
  382. attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
  383. # Normalize the attention scores to probabilities.
  384. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  385. # This is actually dropping out entire tokens to attend to, which might
  386. # seem a bit unusual, but is taken from the original Transformer paper.
  387. attention_probs = self.dropout(attention_probs)
  388. # Mask heads if we want to
  389. if head_mask is not None:
  390. attention_probs = attention_probs * head_mask
  391. context_layer = torch.matmul(attention_probs, value_layer)
  392. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  393. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  394. context_layer = context_layer.view(new_context_layer_shape)
  395. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  396. return outputs
  397. class SwinSelfOutput(nn.Module):
  398. def __init__(self, config, dim):
  399. super().__init__()
  400. self.dense = nn.Linear(dim, dim)
  401. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  402. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  403. hidden_states = self.dense(hidden_states)
  404. hidden_states = self.dropout(hidden_states)
  405. return hidden_states
  406. class SwinAttention(nn.Module):
  407. def __init__(self, config, dim, num_heads, window_size):
  408. super().__init__()
  409. self.self = SwinSelfAttention(config, dim, num_heads, window_size)
  410. self.output = SwinSelfOutput(config, dim)
  411. self.pruned_heads = set()
  412. def prune_heads(self, heads):
  413. if len(heads) == 0:
  414. return
  415. heads, index = find_pruneable_heads_and_indices(
  416. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  417. )
  418. # Prune linear layers
  419. self.self.query = prune_linear_layer(self.self.query, index)
  420. self.self.key = prune_linear_layer(self.self.key, index)
  421. self.self.value = prune_linear_layer(self.self.value, index)
  422. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  423. # Update hyper params and store pruned heads
  424. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  425. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  426. self.pruned_heads = self.pruned_heads.union(heads)
  427. def forward(
  428. self,
  429. hidden_states: torch.Tensor,
  430. attention_mask: Optional[torch.FloatTensor] = None,
  431. head_mask: Optional[torch.FloatTensor] = None,
  432. output_attentions: Optional[bool] = False,
  433. ) -> tuple[torch.Tensor]:
  434. self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
  435. attention_output = self.output(self_outputs[0], hidden_states)
  436. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  437. return outputs
  438. class SwinIntermediate(nn.Module):
  439. def __init__(self, config, dim):
  440. super().__init__()
  441. self.dense = nn.Linear(dim, int(config.mlp_ratio * dim))
  442. if isinstance(config.hidden_act, str):
  443. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  444. else:
  445. self.intermediate_act_fn = config.hidden_act
  446. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  447. hidden_states = self.dense(hidden_states)
  448. hidden_states = self.intermediate_act_fn(hidden_states)
  449. return hidden_states
  450. class SwinOutput(nn.Module):
  451. def __init__(self, config, dim):
  452. super().__init__()
  453. self.dense = nn.Linear(int(config.mlp_ratio * dim), dim)
  454. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  455. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  456. hidden_states = self.dense(hidden_states)
  457. hidden_states = self.dropout(hidden_states)
  458. return hidden_states
  459. class SwinLayer(nn.Module):
  460. def __init__(self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0):
  461. super().__init__()
  462. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  463. self.shift_size = shift_size
  464. self.window_size = config.window_size
  465. self.input_resolution = input_resolution
  466. self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  467. self.attention = SwinAttention(config, dim, num_heads, window_size=self.window_size)
  468. self.drop_path = SwinDropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
  469. self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)
  470. self.intermediate = SwinIntermediate(config, dim)
  471. self.output = SwinOutput(config, dim)
  472. def set_shift_and_window_size(self, input_resolution):
  473. if min(input_resolution) <= self.window_size:
  474. # if window size is larger than input resolution, we don't partition windows
  475. self.shift_size = torch_int(0)
  476. self.window_size = (
  477. torch.min(torch.tensor(input_resolution)) if torch.jit.is_tracing() else min(input_resolution)
  478. )
  479. def get_attn_mask(self, height, width, dtype, device):
  480. if self.shift_size > 0:
  481. # calculate attention mask for SW-MSA
  482. img_mask = torch.zeros((1, height, width, 1), dtype=dtype, device=device)
  483. height_slices = (
  484. slice(0, -self.window_size),
  485. slice(-self.window_size, -self.shift_size),
  486. slice(-self.shift_size, None),
  487. )
  488. width_slices = (
  489. slice(0, -self.window_size),
  490. slice(-self.window_size, -self.shift_size),
  491. slice(-self.shift_size, None),
  492. )
  493. count = 0
  494. for height_slice in height_slices:
  495. for width_slice in width_slices:
  496. img_mask[:, height_slice, width_slice, :] = count
  497. count += 1
  498. mask_windows = window_partition(img_mask, self.window_size)
  499. mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
  500. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  501. attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
  502. else:
  503. attn_mask = None
  504. return attn_mask
  505. def maybe_pad(self, hidden_states, height, width):
  506. pad_right = (self.window_size - width % self.window_size) % self.window_size
  507. pad_bottom = (self.window_size - height % self.window_size) % self.window_size
  508. pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
  509. hidden_states = nn.functional.pad(hidden_states, pad_values)
  510. return hidden_states, pad_values
  511. def forward(
  512. self,
  513. hidden_states: torch.Tensor,
  514. input_dimensions: tuple[int, int],
  515. head_mask: Optional[torch.FloatTensor] = None,
  516. output_attentions: Optional[bool] = False,
  517. always_partition: Optional[bool] = False,
  518. ) -> tuple[torch.Tensor, torch.Tensor]:
  519. if not always_partition:
  520. self.set_shift_and_window_size(input_dimensions)
  521. else:
  522. pass
  523. height, width = input_dimensions
  524. batch_size, _, channels = hidden_states.size()
  525. shortcut = hidden_states
  526. hidden_states = self.layernorm_before(hidden_states)
  527. hidden_states = hidden_states.view(batch_size, height, width, channels)
  528. # pad hidden_states to multiples of window size
  529. hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
  530. _, height_pad, width_pad, _ = hidden_states.shape
  531. # cyclic shift
  532. if self.shift_size > 0:
  533. shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
  534. else:
  535. shifted_hidden_states = hidden_states
  536. # partition windows
  537. hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
  538. hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
  539. attn_mask = self.get_attn_mask(
  540. height_pad, width_pad, dtype=hidden_states.dtype, device=hidden_states_windows.device
  541. )
  542. attention_outputs = self.attention(
  543. hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions
  544. )
  545. attention_output = attention_outputs[0]
  546. attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels)
  547. shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad)
  548. # reverse cyclic shift
  549. if self.shift_size > 0:
  550. attention_windows = torch.roll(shifted_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
  551. else:
  552. attention_windows = shifted_windows
  553. was_padded = pad_values[3] > 0 or pad_values[5] > 0
  554. if was_padded:
  555. attention_windows = attention_windows[:, :height, :width, :].contiguous()
  556. attention_windows = attention_windows.view(batch_size, height * width, channels)
  557. hidden_states = shortcut + self.drop_path(attention_windows)
  558. layer_output = self.layernorm_after(hidden_states)
  559. layer_output = self.intermediate(layer_output)
  560. layer_output = hidden_states + self.output(layer_output)
  561. layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,)
  562. return layer_outputs
  563. class SwinStage(GradientCheckpointingLayer):
  564. def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample):
  565. super().__init__()
  566. self.config = config
  567. self.dim = dim
  568. self.blocks = nn.ModuleList(
  569. [
  570. SwinLayer(
  571. config=config,
  572. dim=dim,
  573. input_resolution=input_resolution,
  574. num_heads=num_heads,
  575. drop_path_rate=drop_path[i],
  576. shift_size=0 if (i % 2 == 0) else config.window_size // 2,
  577. )
  578. for i in range(depth)
  579. ]
  580. )
  581. # patch merging layer
  582. if downsample is not None:
  583. self.downsample = downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
  584. else:
  585. self.downsample = None
  586. self.pointing = False
  587. def forward(
  588. self,
  589. hidden_states: torch.Tensor,
  590. input_dimensions: tuple[int, int],
  591. head_mask: Optional[torch.FloatTensor] = None,
  592. output_attentions: Optional[bool] = False,
  593. always_partition: Optional[bool] = False,
  594. ) -> tuple[torch.Tensor]:
  595. height, width = input_dimensions
  596. for i, layer_module in enumerate(self.blocks):
  597. layer_head_mask = head_mask[i] if head_mask is not None else None
  598. layer_outputs = layer_module(
  599. hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
  600. )
  601. hidden_states = layer_outputs[0]
  602. hidden_states_before_downsampling = hidden_states
  603. if self.downsample is not None:
  604. height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2
  605. output_dimensions = (height, width, height_downsampled, width_downsampled)
  606. hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)
  607. else:
  608. output_dimensions = (height, width, height, width)
  609. stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
  610. if output_attentions:
  611. stage_outputs += layer_outputs[1:]
  612. return stage_outputs
  613. class SwinEncoder(nn.Module):
  614. def __init__(self, config, grid_size):
  615. super().__init__()
  616. self.num_layers = len(config.depths)
  617. self.config = config
  618. dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
  619. self.layers = nn.ModuleList(
  620. [
  621. SwinStage(
  622. config=config,
  623. dim=int(config.embed_dim * 2**i_layer),
  624. input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)),
  625. depth=config.depths[i_layer],
  626. num_heads=config.num_heads[i_layer],
  627. drop_path=dpr[sum(config.depths[:i_layer]) : sum(config.depths[: i_layer + 1])],
  628. downsample=SwinPatchMerging if (i_layer < self.num_layers - 1) else None,
  629. )
  630. for i_layer in range(self.num_layers)
  631. ]
  632. )
  633. self.gradient_checkpointing = False
  634. def forward(
  635. self,
  636. hidden_states: torch.Tensor,
  637. input_dimensions: tuple[int, int],
  638. head_mask: Optional[torch.FloatTensor] = None,
  639. output_attentions: Optional[bool] = False,
  640. output_hidden_states: Optional[bool] = False,
  641. output_hidden_states_before_downsampling: Optional[bool] = False,
  642. always_partition: Optional[bool] = False,
  643. return_dict: Optional[bool] = True,
  644. ) -> Union[tuple, SwinEncoderOutput]:
  645. all_hidden_states = () if output_hidden_states else None
  646. all_reshaped_hidden_states = () if output_hidden_states else None
  647. all_self_attentions = () if output_attentions else None
  648. if output_hidden_states:
  649. batch_size, _, hidden_size = hidden_states.shape
  650. # rearrange b (h w) c -> b c h w
  651. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  652. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  653. all_hidden_states += (hidden_states,)
  654. all_reshaped_hidden_states += (reshaped_hidden_state,)
  655. for i, layer_module in enumerate(self.layers):
  656. layer_head_mask = head_mask[i] if head_mask is not None else None
  657. layer_outputs = layer_module(
  658. hidden_states, input_dimensions, layer_head_mask, output_attentions, always_partition
  659. )
  660. hidden_states = layer_outputs[0]
  661. hidden_states_before_downsampling = layer_outputs[1]
  662. output_dimensions = layer_outputs[2]
  663. input_dimensions = (output_dimensions[-2], output_dimensions[-1])
  664. if output_hidden_states and output_hidden_states_before_downsampling:
  665. batch_size, _, hidden_size = hidden_states_before_downsampling.shape
  666. # rearrange b (h w) c -> b c h w
  667. # here we use the original (not downsampled) height and width
  668. reshaped_hidden_state = hidden_states_before_downsampling.view(
  669. batch_size, *(output_dimensions[0], output_dimensions[1]), hidden_size
  670. )
  671. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  672. all_hidden_states += (hidden_states_before_downsampling,)
  673. all_reshaped_hidden_states += (reshaped_hidden_state,)
  674. elif output_hidden_states and not output_hidden_states_before_downsampling:
  675. batch_size, _, hidden_size = hidden_states.shape
  676. # rearrange b (h w) c -> b c h w
  677. reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size)
  678. reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2)
  679. all_hidden_states += (hidden_states,)
  680. all_reshaped_hidden_states += (reshaped_hidden_state,)
  681. if output_attentions:
  682. all_self_attentions += layer_outputs[3:]
  683. if not return_dict:
  684. return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
  685. return SwinEncoderOutput(
  686. last_hidden_state=hidden_states,
  687. hidden_states=all_hidden_states,
  688. attentions=all_self_attentions,
  689. reshaped_hidden_states=all_reshaped_hidden_states,
  690. )
  691. @auto_docstring
  692. class SwinPreTrainedModel(PreTrainedModel):
  693. config: SwinConfig
  694. base_model_prefix = "swin"
  695. main_input_name = "pixel_values"
  696. supports_gradient_checkpointing = True
  697. _no_split_modules = ["SwinStage"]
  698. def _init_weights(self, module):
  699. """Initialize the weights"""
  700. if isinstance(module, (nn.Linear, nn.Conv2d)):
  701. # Slightly different from the TF version which uses truncated_normal for initialization
  702. # cf https://github.com/pytorch/pytorch/pull/5617
  703. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  704. if module.bias is not None:
  705. module.bias.data.zero_()
  706. elif isinstance(module, nn.LayerNorm):
  707. module.bias.data.zero_()
  708. module.weight.data.fill_(1.0)
  709. elif isinstance(module, SwinEmbeddings):
  710. if module.mask_token is not None:
  711. module.mask_token.data.zero_()
  712. if module.position_embeddings is not None:
  713. module.position_embeddings.data.zero_()
  714. elif isinstance(module, SwinSelfAttention):
  715. module.relative_position_bias_table.data.zero_()
  716. @auto_docstring
  717. class SwinModel(SwinPreTrainedModel):
  718. def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
  719. r"""
  720. add_pooling_layer (`bool`, *optional*, defaults to `True`):
  721. Whether or not to apply pooling layer.
  722. use_mask_token (`bool`, *optional*, defaults to `False`):
  723. Whether or not to create and apply mask tokens in the embedding layer.
  724. """
  725. super().__init__(config)
  726. self.config = config
  727. self.num_layers = len(config.depths)
  728. self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
  729. self.embeddings = SwinEmbeddings(config, use_mask_token=use_mask_token)
  730. self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
  731. self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
  732. self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
  733. # Initialize weights and apply final processing
  734. self.post_init()
  735. def get_input_embeddings(self):
  736. return self.embeddings.patch_embeddings
  737. def _prune_heads(self, heads_to_prune):
  738. """
  739. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  740. class PreTrainedModel
  741. """
  742. for layer, heads in heads_to_prune.items():
  743. self.encoder.layer[layer].attention.prune_heads(heads)
  744. @auto_docstring
  745. def forward(
  746. self,
  747. pixel_values: Optional[torch.FloatTensor] = None,
  748. bool_masked_pos: Optional[torch.BoolTensor] = None,
  749. head_mask: Optional[torch.FloatTensor] = None,
  750. output_attentions: Optional[bool] = None,
  751. output_hidden_states: Optional[bool] = None,
  752. interpolate_pos_encoding: bool = False,
  753. return_dict: Optional[bool] = None,
  754. ) -> Union[tuple, SwinModelOutput]:
  755. r"""
  756. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
  757. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  758. """
  759. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  760. output_hidden_states = (
  761. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  762. )
  763. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  764. if pixel_values is None:
  765. raise ValueError("You have to specify pixel_values")
  766. # Prepare head mask if needed
  767. # 1.0 in head_mask indicate we keep the head
  768. # attention_probs has shape bsz x n_heads x N x N
  769. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  770. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  771. head_mask = self.get_head_mask(head_mask, len(self.config.depths))
  772. embedding_output, input_dimensions = self.embeddings(
  773. pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
  774. )
  775. encoder_outputs = self.encoder(
  776. embedding_output,
  777. input_dimensions,
  778. head_mask=head_mask,
  779. output_attentions=output_attentions,
  780. output_hidden_states=output_hidden_states,
  781. return_dict=return_dict,
  782. )
  783. sequence_output = encoder_outputs[0]
  784. sequence_output = self.layernorm(sequence_output)
  785. pooled_output = None
  786. if self.pooler is not None:
  787. pooled_output = self.pooler(sequence_output.transpose(1, 2))
  788. pooled_output = torch.flatten(pooled_output, 1)
  789. if not return_dict:
  790. output = (sequence_output, pooled_output) + encoder_outputs[1:]
  791. return output
  792. return SwinModelOutput(
  793. last_hidden_state=sequence_output,
  794. pooler_output=pooled_output,
  795. hidden_states=encoder_outputs.hidden_states,
  796. attentions=encoder_outputs.attentions,
  797. reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
  798. )
  799. @auto_docstring(
  800. custom_intro="""
  801. Swin Model with a decoder on top for masked image modeling, as proposed in [SimMIM](https://huggingface.co/papers/2111.09886).
  802. <Tip>
  803. Note that we provide a script to pre-train this model on custom data in our [examples
  804. directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
  805. </Tip>
  806. """
  807. )
  808. class SwinForMaskedImageModeling(SwinPreTrainedModel):
  809. def __init__(self, config):
  810. super().__init__(config)
  811. self.swin = SwinModel(config, add_pooling_layer=False, use_mask_token=True)
  812. num_features = int(config.embed_dim * 2 ** (config.num_layers - 1))
  813. self.decoder = nn.Sequential(
  814. nn.Conv2d(
  815. in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
  816. ),
  817. nn.PixelShuffle(config.encoder_stride),
  818. )
  819. # Initialize weights and apply final processing
  820. self.post_init()
  821. @auto_docstring
  822. def forward(
  823. self,
  824. pixel_values: Optional[torch.FloatTensor] = None,
  825. bool_masked_pos: Optional[torch.BoolTensor] = None,
  826. head_mask: Optional[torch.FloatTensor] = None,
  827. output_attentions: Optional[bool] = None,
  828. output_hidden_states: Optional[bool] = None,
  829. interpolate_pos_encoding: bool = False,
  830. return_dict: Optional[bool] = None,
  831. ) -> Union[tuple, SwinMaskedImageModelingOutput]:
  832. r"""
  833. bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
  834. Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
  835. Examples:
  836. ```python
  837. >>> from transformers import AutoImageProcessor, SwinForMaskedImageModeling
  838. >>> import torch
  839. >>> from PIL import Image
  840. >>> import requests
  841. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  842. >>> image = Image.open(requests.get(url, stream=True).raw)
  843. >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swin-base-simmim-window6-192")
  844. >>> model = SwinForMaskedImageModeling.from_pretrained("microsoft/swin-base-simmim-window6-192")
  845. >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
  846. >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
  847. >>> # create random boolean mask of shape (batch_size, num_patches)
  848. >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
  849. >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
  850. >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
  851. >>> list(reconstructed_pixel_values.shape)
  852. [1, 3, 192, 192]
  853. ```"""
  854. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  855. outputs = self.swin(
  856. pixel_values,
  857. bool_masked_pos=bool_masked_pos,
  858. head_mask=head_mask,
  859. output_attentions=output_attentions,
  860. output_hidden_states=output_hidden_states,
  861. interpolate_pos_encoding=interpolate_pos_encoding,
  862. return_dict=return_dict,
  863. )
  864. sequence_output = outputs[0]
  865. # Reshape to (batch_size, num_channels, height, width)
  866. sequence_output = sequence_output.transpose(1, 2)
  867. batch_size, num_channels, sequence_length = sequence_output.shape
  868. height = width = math.floor(sequence_length**0.5)
  869. sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)
  870. # Reconstruct pixel values
  871. reconstructed_pixel_values = self.decoder(sequence_output)
  872. masked_im_loss = None
  873. if bool_masked_pos is not None:
  874. size = self.config.image_size // self.config.patch_size
  875. bool_masked_pos = bool_masked_pos.reshape(-1, size, size)
  876. mask = (
  877. bool_masked_pos.repeat_interleave(self.config.patch_size, 1)
  878. .repeat_interleave(self.config.patch_size, 2)
  879. .unsqueeze(1)
  880. .contiguous()
  881. )
  882. reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
  883. masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels
  884. if not return_dict:
  885. output = (reconstructed_pixel_values,) + outputs[2:]
  886. return ((masked_im_loss,) + output) if masked_im_loss is not None else output
  887. return SwinMaskedImageModelingOutput(
  888. loss=masked_im_loss,
  889. reconstruction=reconstructed_pixel_values,
  890. hidden_states=outputs.hidden_states,
  891. attentions=outputs.attentions,
  892. reshaped_hidden_states=outputs.reshaped_hidden_states,
  893. )
  894. @auto_docstring(
  895. custom_intro="""
  896. Swin Model transformer with an image classification head on top (a linear layer on top of the final hidden state of
  897. the [CLS] token) e.g. for ImageNet.
  898. <Tip>
  899. Note that it's possible to fine-tune Swin on higher resolution images than the ones it has been trained on, by
  900. setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
  901. position embeddings to the higher resolution.
  902. </Tip>
  903. """
  904. )
  905. class SwinForImageClassification(SwinPreTrainedModel):
  906. def __init__(self, config):
  907. super().__init__(config)
  908. self.num_labels = config.num_labels
  909. self.swin = SwinModel(config)
  910. # Classifier head
  911. self.classifier = (
  912. nn.Linear(self.swin.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
  913. )
  914. # Initialize weights and apply final processing
  915. self.post_init()
  916. @auto_docstring
  917. def forward(
  918. self,
  919. pixel_values: Optional[torch.FloatTensor] = None,
  920. head_mask: Optional[torch.FloatTensor] = None,
  921. labels: Optional[torch.LongTensor] = None,
  922. output_attentions: Optional[bool] = None,
  923. output_hidden_states: Optional[bool] = None,
  924. interpolate_pos_encoding: bool = False,
  925. return_dict: Optional[bool] = None,
  926. ) -> Union[tuple, SwinImageClassifierOutput]:
  927. r"""
  928. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  929. Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
  930. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  931. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  932. """
  933. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  934. outputs = self.swin(
  935. pixel_values,
  936. head_mask=head_mask,
  937. output_attentions=output_attentions,
  938. output_hidden_states=output_hidden_states,
  939. interpolate_pos_encoding=interpolate_pos_encoding,
  940. return_dict=return_dict,
  941. )
  942. pooled_output = outputs[1]
  943. logits = self.classifier(pooled_output)
  944. loss = None
  945. if labels is not None:
  946. loss = self.loss_function(labels, logits, self.config)
  947. if not return_dict:
  948. output = (logits,) + outputs[2:]
  949. return ((loss,) + output) if loss is not None else output
  950. return SwinImageClassifierOutput(
  951. loss=loss,
  952. logits=logits,
  953. hidden_states=outputs.hidden_states,
  954. attentions=outputs.attentions,
  955. reshaped_hidden_states=outputs.reshaped_hidden_states,
  956. )
  957. @auto_docstring(
  958. custom_intro="""
  959. Swin backbone, to be used with frameworks like DETR and MaskFormer.
  960. """
  961. )
  962. class SwinBackbone(SwinPreTrainedModel, BackboneMixin):
  963. def __init__(self, config: SwinConfig):
  964. super().__init__(config)
  965. super()._init_backbone(config)
  966. self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
  967. self.embeddings = SwinEmbeddings(config)
  968. self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
  969. # Add layer norms to hidden states of out_features
  970. hidden_states_norms = {}
  971. for stage, num_channels in zip(self._out_features, self.channels):
  972. hidden_states_norms[stage] = nn.LayerNorm(num_channels)
  973. self.hidden_states_norms = nn.ModuleDict(hidden_states_norms)
  974. # Initialize weights and apply final processing
  975. self.post_init()
  976. def get_input_embeddings(self):
  977. return self.embeddings.patch_embeddings
  978. def forward(
  979. self,
  980. pixel_values: torch.Tensor,
  981. output_hidden_states: Optional[bool] = None,
  982. output_attentions: Optional[bool] = None,
  983. return_dict: Optional[bool] = None,
  984. ) -> BackboneOutput:
  985. """
  986. Returns:
  987. Examples:
  988. ```python
  989. >>> from transformers import AutoImageProcessor, AutoBackbone
  990. >>> import torch
  991. >>> from PIL import Image
  992. >>> import requests
  993. >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
  994. >>> image = Image.open(requests.get(url, stream=True).raw)
  995. >>> processor = AutoImageProcessor.from_pretrained("shi-labs/nat-mini-in1k-224")
  996. >>> model = AutoBackbone.from_pretrained(
  997. ... "microsoft/swin-tiny-patch4-window7-224", out_features=["stage1", "stage2", "stage3", "stage4"]
  998. ... )
  999. >>> inputs = processor(image, return_tensors="pt")
  1000. >>> outputs = model(**inputs)
  1001. >>> feature_maps = outputs.feature_maps
  1002. >>> list(feature_maps[-1].shape)
  1003. [1, 768, 7, 7]
  1004. ```"""
  1005. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1006. output_hidden_states = (
  1007. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1008. )
  1009. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1010. embedding_output, input_dimensions = self.embeddings(pixel_values)
  1011. outputs = self.encoder(
  1012. embedding_output,
  1013. input_dimensions,
  1014. head_mask=None,
  1015. output_attentions=output_attentions,
  1016. output_hidden_states=True,
  1017. output_hidden_states_before_downsampling=True,
  1018. always_partition=True,
  1019. return_dict=True,
  1020. )
  1021. hidden_states = outputs.reshaped_hidden_states
  1022. feature_maps = ()
  1023. for stage, hidden_state in zip(self.stage_names, hidden_states):
  1024. if stage in self.out_features:
  1025. batch_size, num_channels, height, width = hidden_state.shape
  1026. hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
  1027. hidden_state = hidden_state.view(batch_size, height * width, num_channels)
  1028. hidden_state = self.hidden_states_norms[stage](hidden_state)
  1029. hidden_state = hidden_state.view(batch_size, height, width, num_channels)
  1030. hidden_state = hidden_state.permute(0, 3, 1, 2).contiguous()
  1031. feature_maps += (hidden_state,)
  1032. if not return_dict:
  1033. output = (feature_maps,)
  1034. if output_hidden_states:
  1035. output += (outputs.hidden_states,)
  1036. return output
  1037. return BackboneOutput(
  1038. feature_maps=feature_maps,
  1039. hidden_states=outputs.hidden_states if output_hidden_states else None,
  1040. attentions=outputs.attentions,
  1041. )
# Public names exported by this module (the model classes defined above).
__all__ = [
    "SwinForImageClassification",
    "SwinForMaskedImageModeling",
    "SwinModel",
    "SwinPreTrainedModel",
    "SwinBackbone",
]