modeling_swinv2.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342
  1. # coding=utf-8
  2. # Copyright 2022 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Swinv2 Transformer model."""
  16. import collections.abc
  17. import math
  18. import warnings
  19. from dataclasses import dataclass
  20. from typing import Optional, Union
  21. import torch
  22. from torch import Tensor, nn
  23. from ...activations import ACT2FN
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import BackboneOutput
  26. from ...modeling_utils import PreTrainedModel
  27. from ...pytorch_utils import find_pruneable_heads_and_indices, meshgrid, prune_linear_layer
  28. from ...utils import ModelOutput, auto_docstring, logging, torch_int
  29. from ...utils.backbone_utils import BackboneMixin
  30. from .configuration_swinv2 import Swinv2Config
# Module-level logger, created through the package's `logging` utilities (imported from `...utils`).
logger = logging.get_logger(__name__)
# drop_path, Swinv2PatchEmbeddings, Swinv2PatchMerging and Swinv2DropPath are from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer_v2.py.
@dataclass
@auto_docstring(
    custom_intro="""
Swinv2 encoder's outputs, with potential hidden states and attentions.
"""
)
# Copied from transformers.models.swin.modeling_swin.SwinEncoderOutput with Swin->Swinv2
class Swinv2EncoderOutput(ModelOutput):
    r"""
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field is optional and defaults to None.
    # NOTE(review): field order presumably also defines the positional / tuple order of this ModelOutput;
    # keep it stable unless callers are audited.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Spatially reshaped counterpart of `hidden_states` (see the class docstring).
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Swinv2 model's outputs that also contains a pooling of the last hidden states.
"""
)
# Copied from transformers.models.swin.modeling_swin.SwinModelOutput with Swin->Swinv2
class Swinv2ModelOutput(ModelOutput):
    r"""
    pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*, returned when `add_pooling_layer=True` is passed):
        Average pooling of the last layer hidden-state.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field is optional and defaults to None.
    # NOTE(review): field order presumably also defines the positional / tuple order of this ModelOutput;
    # keep it stable unless callers are audited.
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Only populated when the pooling layer is enabled (see the class docstring).
    pooler_output: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Spatially reshaped counterpart of `hidden_states` (see the class docstring).
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Swinv2 masked image model outputs.
"""
)
# Copied from transformers.models.swin.modeling_swin.SwinMaskedImageModelingOutput with Swin->Swinv2
class Swinv2MaskedImageModelingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
        Masked image modeling (MIM) loss.
    reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        Reconstructed pixel values.
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field is optional and defaults to None.
    # NOTE(review): field order presumably also defines the positional / tuple order of this ModelOutput;
    # keep it stable unless callers are audited.
    loss: Optional[torch.FloatTensor] = None
    reconstruction: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Spatially reshaped counterpart of `hidden_states` (see the class docstring).
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None

    @property
    def logits(self):
        """Deprecated alias of `reconstruction`; emits a `FutureWarning` on every access."""
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
        )
        return self.reconstruction
@dataclass
@auto_docstring(
    custom_intro="""
Swinv2 outputs for image classification.
"""
)
# Copied from transformers.models.swin.modeling_swin.SwinImageClassifierOutput with Swin->Swinv2
class Swinv2ImageClassifierOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Classification (or regression if config.num_labels==1) loss.
    logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
        Classification (or regression if config.num_labels==1) scores (before SoftMax).
    reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
        shape `(batch_size, hidden_size, height, width)`.

        Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
        include the spatial dimensions.
    """

    # Every field is optional and defaults to None.
    # NOTE(review): field order presumably also defines the positional / tuple order of this ModelOutput;
    # keep it stable unless callers are audited.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # Spatially reshaped counterpart of `hidden_states` (see the class docstring).
    reshaped_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  130. # Copied from transformers.models.swin.modeling_swin.window_partition
  131. def window_partition(input_feature, window_size):
  132. """
  133. Partitions the given input into windows.
  134. """
  135. batch_size, height, width, num_channels = input_feature.shape
  136. input_feature = input_feature.view(
  137. batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
  138. )
  139. windows = input_feature.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
  140. return windows
  141. # Copied from transformers.models.swin.modeling_swin.window_reverse
  142. def window_reverse(windows, window_size, height, width):
  143. """
  144. Merges windows to produce higher resolution features.
  145. """
  146. num_channels = windows.shape[-1]
  147. windows = windows.view(-1, height // window_size, width // window_size, window_size, window_size, num_channels)
  148. windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
  149. return windows
  150. # Copied from transformers.models.swin.modeling_swin.drop_path
  151. def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
  152. """
  153. Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  154. Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
  155. however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  156. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
  157. layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
  158. argument.
  159. """
  160. if drop_prob == 0.0 or not training:
  161. return input
  162. keep_prob = 1 - drop_prob
  163. shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  164. random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
  165. random_tensor.floor_() # binarize
  166. output = input.div(keep_prob) * random_tensor
  167. return output
  168. # Copied from transformers.models.swin.modeling_swin.SwinDropPath with Swin->Swinv2
  169. class Swinv2DropPath(nn.Module):
  170. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  171. def __init__(self, drop_prob: Optional[float] = None) -> None:
  172. super().__init__()
  173. self.drop_prob = drop_prob
  174. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  175. return drop_path(hidden_states, self.drop_prob, self.training)
  176. def extra_repr(self) -> str:
  177. return f"p={self.drop_prob}"
  178. # Copied from transformers.models.swin.modeling_swin.SwinEmbeddings with Swin->Swinv2
  179. class Swinv2Embeddings(nn.Module):
  180. """
  181. Construct the patch and position embeddings. Optionally, also the mask token.
  182. """
  183. def __init__(self, config, use_mask_token=False):
  184. super().__init__()
  185. self.patch_embeddings = Swinv2PatchEmbeddings(config)
  186. num_patches = self.patch_embeddings.num_patches
  187. self.patch_grid = self.patch_embeddings.grid_size
  188. self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None
  189. if config.use_absolute_embeddings:
  190. self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))
  191. else:
  192. self.position_embeddings = None
  193. self.norm = nn.LayerNorm(config.embed_dim)
  194. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  195. self.patch_size = config.patch_size
  196. self.config = config
  197. # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
  198. def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
  199. """
  200. This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
  201. images. This method is also adapted to support torch.jit tracing.
  202. Adapted from:
  203. - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
  204. - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
  205. """
  206. num_patches = embeddings.shape[1] - 1
  207. num_positions = self.position_embeddings.shape[1] - 1
  208. # always interpolate when tracing to ensure the exported model works for dynamic input shapes
  209. if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
  210. return self.position_embeddings
  211. class_pos_embed = self.position_embeddings[:, :1]
  212. patch_pos_embed = self.position_embeddings[:, 1:]
  213. dim = embeddings.shape[-1]
  214. new_height = height // self.patch_size
  215. new_width = width // self.patch_size
  216. sqrt_num_positions = torch_int(num_positions**0.5)
  217. patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
  218. patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
  219. patch_pos_embed = nn.functional.interpolate(
  220. patch_pos_embed,
  221. size=(new_height, new_width),
  222. mode="bicubic",
  223. align_corners=False,
  224. )
  225. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  226. return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
  227. def forward(
  228. self,
  229. pixel_values: Optional[torch.FloatTensor],
  230. bool_masked_pos: Optional[torch.BoolTensor] = None,
  231. interpolate_pos_encoding: bool = False,
  232. ) -> tuple[torch.Tensor]:
  233. _, num_channels, height, width = pixel_values.shape
  234. embeddings, output_dimensions = self.patch_embeddings(pixel_values)
  235. embeddings = self.norm(embeddings)
  236. batch_size, seq_len, _ = embeddings.size()
  237. if bool_masked_pos is not None:
  238. mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
  239. # replace the masked visual tokens by mask_tokens
  240. mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
  241. embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
  242. if self.position_embeddings is not None:
  243. if interpolate_pos_encoding:
  244. embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
  245. else:
  246. embeddings = embeddings + self.position_embeddings
  247. embeddings = self.dropout(embeddings)
  248. return embeddings, output_dimensions
  249. # Copied from transformers.models.swin.modeling_swin.SwinPatchEmbeddings with Swin->Swinv2
  250. class Swinv2PatchEmbeddings(nn.Module):
  251. """
  252. This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
  253. `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
  254. Transformer.
  255. """
  256. def __init__(self, config):
  257. super().__init__()
  258. image_size, patch_size = config.image_size, config.patch_size
  259. num_channels, hidden_size = config.num_channels, config.embed_dim
  260. image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
  261. patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
  262. num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
  263. self.image_size = image_size
  264. self.patch_size = patch_size
  265. self.num_channels = num_channels
  266. self.num_patches = num_patches
  267. self.grid_size = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
  268. self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
  269. def maybe_pad(self, pixel_values, height, width):
  270. if width % self.patch_size[1] != 0:
  271. pad_values = (0, self.patch_size[1] - width % self.patch_size[1])
  272. pixel_values = nn.functional.pad(pixel_values, pad_values)
  273. if height % self.patch_size[0] != 0:
  274. pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0])
  275. pixel_values = nn.functional.pad(pixel_values, pad_values)
  276. return pixel_values
  277. def forward(self, pixel_values: Optional[torch.FloatTensor]) -> tuple[torch.Tensor, tuple[int]]:
  278. _, num_channels, height, width = pixel_values.shape
  279. # pad the input to be divisible by self.patch_size, if needed
  280. pixel_values = self.maybe_pad(pixel_values, height, width)
  281. embeddings = self.projection(pixel_values)
  282. _, _, height, width = embeddings.shape
  283. output_dimensions = (height, width)
  284. embeddings = embeddings.flatten(2).transpose(1, 2)
  285. return embeddings, output_dimensions
  286. class Swinv2PatchMerging(nn.Module):
  287. """
  288. Patch Merging Layer.
  289. Args:
  290. input_resolution (`tuple[int]`):
  291. Resolution of input feature.
  292. dim (`int`):
  293. Number of input channels.
  294. norm_layer (`nn.Module`, *optional*, defaults to `nn.LayerNorm`):
  295. Normalization layer class.
  296. """
  297. def __init__(self, input_resolution: tuple[int], dim: int, norm_layer: nn.Module = nn.LayerNorm) -> None:
  298. super().__init__()
  299. self.input_resolution = input_resolution
  300. self.dim = dim
  301. self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
  302. self.norm = norm_layer(2 * dim)
  303. def maybe_pad(self, input_feature, height, width):
  304. should_pad = (height % 2 == 1) or (width % 2 == 1)
  305. if should_pad:
  306. pad_values = (0, 0, 0, width % 2, 0, height % 2)
  307. input_feature = nn.functional.pad(input_feature, pad_values)
  308. return input_feature
  309. def forward(self, input_feature: torch.Tensor, input_dimensions: tuple[int, int]) -> torch.Tensor:
  310. height, width = input_dimensions
  311. # `dim` is height * width
  312. batch_size, dim, num_channels = input_feature.shape
  313. input_feature = input_feature.view(batch_size, height, width, num_channels)
  314. # pad input to be divisible by width and height, if needed
  315. input_feature = self.maybe_pad(input_feature, height, width)
  316. # [batch_size, height/2, width/2, num_channels]
  317. input_feature_0 = input_feature[:, 0::2, 0::2, :]
  318. # [batch_size, height/2, width/2, num_channels]
  319. input_feature_1 = input_feature[:, 1::2, 0::2, :]
  320. # [batch_size, height/2, width/2, num_channels]
  321. input_feature_2 = input_feature[:, 0::2, 1::2, :]
  322. # [batch_size, height/2, width/2, num_channels]
  323. input_feature_3 = input_feature[:, 1::2, 1::2, :]
  324. # [batch_size, height/2 * width/2, 4*num_channels]
  325. input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1)
  326. input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # [batch_size, height/2 * width/2, 4*C]
  327. input_feature = self.reduction(input_feature)
  328. input_feature = self.norm(input_feature)
  329. return input_feature
  330. class Swinv2SelfAttention(nn.Module):
  331. def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[0, 0]):
  332. super().__init__()
  333. if dim % num_heads != 0:
  334. raise ValueError(
  335. f"The hidden size ({dim}) is not a multiple of the number of attention heads ({num_heads})"
  336. )
  337. self.num_attention_heads = num_heads
  338. self.attention_head_size = int(dim / num_heads)
  339. self.all_head_size = self.num_attention_heads * self.attention_head_size
  340. self.window_size = (
  341. window_size if isinstance(window_size, collections.abc.Iterable) else (window_size, window_size)
  342. )
  343. self.pretrained_window_size = pretrained_window_size
  344. self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
  345. # mlp to generate continuous relative position bias
  346. self.continuous_position_bias_mlp = nn.Sequential(
  347. nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
  348. )
  349. # get relative_coords_table
  350. relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
  351. relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
  352. relative_coords_table = (
  353. torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
  354. .permute(1, 2, 0)
  355. .contiguous()
  356. .unsqueeze(0)
  357. ) # [1, 2*window_height - 1, 2*window_width - 1, 2]
  358. if pretrained_window_size[0] > 0:
  359. relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
  360. relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
  361. elif window_size > 1:
  362. relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
  363. relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
  364. relative_coords_table *= 8 # normalize to -8, 8
  365. relative_coords_table = (
  366. torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
  367. )
  368. # set to same dtype as mlp weight
  369. relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
  370. self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
  371. # get pair-wise relative position index for each token inside the window
  372. coords_h = torch.arange(self.window_size[0])
  373. coords_w = torch.arange(self.window_size[1])
  374. coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
  375. coords_flatten = torch.flatten(coords, 1)
  376. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
  377. relative_coords = relative_coords.permute(1, 2, 0).contiguous()
  378. relative_coords[:, :, 0] += self.window_size[0] - 1
  379. relative_coords[:, :, 1] += self.window_size[1] - 1
  380. relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
  381. relative_position_index = relative_coords.sum(-1)
  382. self.register_buffer("relative_position_index", relative_position_index, persistent=False)
  383. self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  384. self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=False)
  385. self.value = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
  386. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  387. def forward(
  388. self,
  389. hidden_states: torch.Tensor,
  390. attention_mask: Optional[torch.FloatTensor] = None,
  391. head_mask: Optional[torch.FloatTensor] = None,
  392. output_attentions: Optional[bool] = False,
  393. ) -> tuple[torch.Tensor]:
  394. batch_size, dim, num_channels = hidden_states.shape
  395. query_layer = (
  396. self.query(hidden_states)
  397. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  398. .transpose(1, 2)
  399. )
  400. key_layer = (
  401. self.key(hidden_states)
  402. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  403. .transpose(1, 2)
  404. )
  405. value_layer = (
  406. self.value(hidden_states)
  407. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  408. .transpose(1, 2)
  409. )
  410. # cosine attention
  411. attention_scores = nn.functional.normalize(query_layer, dim=-1) @ nn.functional.normalize(
  412. key_layer, dim=-1
  413. ).transpose(-2, -1)
  414. logit_scale = torch.clamp(self.logit_scale, max=math.log(1.0 / 0.01)).exp()
  415. attention_scores = attention_scores * logit_scale
  416. relative_position_bias_table = self.continuous_position_bias_mlp(self.relative_coords_table).view(
  417. -1, self.num_attention_heads
  418. )
  419. # [window_height*window_width,window_height*window_width,num_attention_heads]
  420. relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
  421. self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
  422. )
  423. # [num_attention_heads,window_height*window_width,window_height*window_width]
  424. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  425. relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
  426. attention_scores = attention_scores + relative_position_bias.unsqueeze(0)
  427. if attention_mask is not None:
  428. # Apply the attention mask is (precomputed for all layers in Swinv2Model forward() function)
  429. mask_shape = attention_mask.shape[0]
  430. attention_scores = attention_scores.view(
  431. batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim
  432. ) + attention_mask.unsqueeze(1).unsqueeze(0)
  433. attention_scores = attention_scores + attention_mask.unsqueeze(1).unsqueeze(0)
  434. attention_scores = attention_scores.view(-1, self.num_attention_heads, dim, dim)
  435. # Normalize the attention scores to probabilities.
  436. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  437. # This is actually dropping out entire tokens to attend to, which might
  438. # seem a bit unusual, but is taken from the original Transformer paper.
  439. attention_probs = self.dropout(attention_probs)
  440. # Mask heads if we want to
  441. if head_mask is not None:
  442. attention_probs = attention_probs * head_mask
  443. context_layer = torch.matmul(attention_probs, value_layer)
  444. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  445. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  446. context_layer = context_layer.view(new_context_layer_shape)
  447. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  448. return outputs
  449. # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->Swinv2
  450. class Swinv2SelfOutput(nn.Module):
  451. def __init__(self, config, dim):
  452. super().__init__()
  453. self.dense = nn.Linear(dim, dim)
  454. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  455. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  456. hidden_states = self.dense(hidden_states)
  457. hidden_states = self.dropout(hidden_states)
  458. return hidden_states
  class Swinv2Attention(nn.Module):
      """Window self-attention plus its output projection.

      The attention core lives in ``self.self`` (a ``Swinv2SelfAttention``), the projection in
      ``self.output`` (a ``Swinv2SelfOutput``), and ``self.pruned_heads`` records heads removed so far.
      """

      def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=0):
          super().__init__()
          # A scalar pretraining window size is expanded to a (height, width) pair.
          if isinstance(pretrained_window_size, collections.abc.Iterable):
              pretrained_pair = pretrained_window_size
          else:
              pretrained_pair = (pretrained_window_size, pretrained_window_size)
          self.self = Swinv2SelfAttention(
              config=config,
              dim=dim,
              num_heads=num_heads,
              window_size=window_size,
              pretrained_window_size=pretrained_pair,
          )
          self.output = Swinv2SelfOutput(config, dim)
          self.pruned_heads = set()

      def prune_heads(self, heads):
          """Drop the given heads from the query/key/value and output projections.

          Args:
              heads: head indices to prune; an empty collection is a no-op.
          """
          if not len(heads):
              return
          core = self.self
          heads, keep_index = find_pruneable_heads_and_indices(
              heads, core.num_attention_heads, core.attention_head_size, self.pruned_heads
          )
          # Q/K/V lose output rows; the output projection loses the matching input columns.
          for projection in ("query", "key", "value"):
              setattr(core, projection, prune_linear_layer(getattr(core, projection), keep_index))
          self.output.dense = prune_linear_layer(self.output.dense, keep_index, dim=1)
          # NOTE(review): only the linear projections are pruned here; other attention parameters such as
          # `logit_scale` are left as-is — confirm they do not need pruning as well.
          core.num_attention_heads = core.num_attention_heads - len(heads)
          core.all_head_size = core.attention_head_size * core.num_attention_heads
          self.pruned_heads = self.pruned_heads.union(heads)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = False,
      ) -> tuple[torch.Tensor]:
          """Run window attention, project its result, and pass any extra outputs through.

          Returns a tuple whose first element is the projected attention output, followed by whatever
          extra entries ``self.self`` returned (the attention probabilities when they are requested).
          """
          core_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions)
          projected = self.output(core_outputs[0], hidden_states)
          return (projected, *core_outputs[1:])
  # Copied from transformers.models.swin.modeling_swin.SwinIntermediate with Swin->Swinv2
  class Swinv2Intermediate(nn.Module):
      """MLP expansion: project ``dim`` -> ``int(mlp_ratio * dim)`` and apply the configured activation."""

      def __init__(self, config, dim):
          super().__init__()
          hidden_width = int(config.mlp_ratio * dim)
          self.dense = nn.Linear(dim, hidden_width)
          activation = config.hidden_act
          # A string is looked up in ACT2FN; any other value is used directly as the activation callable.
          self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          return self.intermediate_act_fn(self.dense(hidden_states))
  # Copied from transformers.models.swin.modeling_swin.SwinOutput with Swin->Swinv2
  class Swinv2Output(nn.Module):
      """MLP contraction: project ``int(mlp_ratio * dim)`` back to ``dim`` and apply dropout."""

      def __init__(self, config, dim):
          super().__init__()
          expanded_width = int(config.mlp_ratio * dim)
          self.dense = nn.Linear(expanded_width, dim)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          projected = self.dense(hidden_states)
          return self.dropout(projected)
  class Swinv2Layer(nn.Module):
      """A single Swinv2 transformer block.

      The block runs (optionally cyclically shifted) window attention followed by an MLP. Both sub-layers use
      post-normalization: the layer norm is applied to the sub-layer output before it is added to the residual.
      """

      def __init__(
          self, config, dim, input_resolution, num_heads, drop_path_rate=0.0, shift_size=0, pretrained_window_size=0
      ):
          super().__init__()
          self.input_resolution = input_resolution
          # Clamp window/shift to the input resolution (relies on self.input_resolution being set first).
          clamped_windows, clamped_shifts = self._compute_window_shift(
              (config.window_size, config.window_size), (shift_size, shift_size)
          )
          self.window_size = clamped_windows[0]
          self.shift_size = clamped_shifts[0]
          if isinstance(pretrained_window_size, collections.abc.Iterable):
              pretrained_pair = pretrained_window_size
          else:
              pretrained_pair = (pretrained_window_size, pretrained_window_size)
          self.attention = Swinv2Attention(
              config=config,
              dim=dim,
              num_heads=num_heads,
              window_size=self.window_size,
              pretrained_window_size=pretrained_pair,
          )
          self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps)
          self.drop_path = Swinv2DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
          self.intermediate = Swinv2Intermediate(config, dim)
          self.output = Swinv2Output(config, dim)
          self.layernorm_after = nn.LayerNorm(dim, eps=config.layer_norm_eps)

      def _compute_window_shift(self, target_window_size, target_shift_size) -> tuple[tuple[int, int], tuple[int, int]]:
          """Clamp each window dimension to the input resolution and drop the shift where one window covers it."""
          window_size = [
              resolution if resolution <= window else window
              for resolution, window in zip(self.input_resolution, target_window_size)
          ]
          # No cyclic shift is needed along an axis that a single window already spans.
          shift_size = [
              0 if resolution <= window else shift
              for resolution, window, shift in zip(self.input_resolution, window_size, target_shift_size)
          ]
          return window_size, shift_size

      def get_attn_mask(self, height, width, dtype):
          """Additive attention mask for shifted windows, or ``None`` when the block is not shifted.

          Token pairs that share a window but carry different region labels get -100.0, all others 0.0.
          """
          if self.shift_size <= 0:
              return None
          window, shift = self.window_size, self.shift_size
          # Split each axis into three bands and label the resulting 3x3 regions 0..8 (row-major).
          bands = (slice(0, -window), slice(-window, -shift), slice(-shift, None))
          img_mask = torch.zeros((1, height, width, 1), dtype=dtype)
          region_grid = [(row_band, col_band) for row_band in bands for col_band in bands]
          for region_id, (row_band, col_band) in enumerate(region_grid):
              img_mask[:, row_band, col_band, :] = region_id
          mask_windows = window_partition(img_mask, window).view(-1, window * window)
          label_diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
          return label_diff.masked_fill(label_diff != 0, -100.0).masked_fill(label_diff == 0, 0.0)

      def maybe_pad(self, hidden_states, height, width):
          """Zero-pad the bottom and right edges so height and width become multiples of the window size.

          Returns the padded tensor together with the pad tuple handed to ``nn.functional.pad``.
          """
          window = self.window_size
          pad_right = (window - width % window) % window
          pad_bottom = (window - height % window) % window
          # nn.functional.pad reads pairs from the last dim backwards: (channels, width, height).
          pad_values = (0, 0, 0, pad_right, 0, pad_bottom)
          return nn.functional.pad(hidden_states, pad_values), pad_values

      def forward(
          self,
          hidden_states: torch.Tensor,
          input_dimensions: tuple[int, int],
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = False,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Apply (shifted) window attention and the MLP to ``(batch, height * width, channels)`` tokens.

          Returns ``(layer_output,)``, or ``(layer_output, attention_probs)`` when ``output_attentions`` is set.
          """
          height, width = input_dimensions
          batch_size, _, channels = hidden_states.size()
          shortcut = hidden_states
          window, shift = self.window_size, self.shift_size

          # Tokens -> (batch, height, width, channels) grid, padded to whole windows.
          grid, pad_values = self.maybe_pad(hidden_states.view(batch_size, height, width, channels), height, width)
          _, height_pad, width_pad, _ = grid.shape

          # Cyclic shift towards the top-left before partitioning into windows.
          if shift > 0:
              shifted_grid = torch.roll(grid, shifts=(-shift, -shift), dims=(1, 2))
          else:
              shifted_grid = grid
          windows = window_partition(shifted_grid, window).view(-1, window * window, channels)

          attn_mask = self.get_attn_mask(height_pad, width_pad, dtype=grid.dtype)
          if attn_mask is not None:
              attn_mask = attn_mask.to(windows.device)
          attention_outputs = self.attention(windows, attn_mask, head_mask, output_attentions=output_attentions)

          merged = window_reverse(
              attention_outputs[0].view(-1, window, window, channels), window, height_pad, width_pad
          )
          # Undo the cyclic shift.
          if shift > 0:
              merged = torch.roll(merged, shifts=(shift, shift), dims=(1, 2))
          # Drop the rows/columns that maybe_pad added.
          if pad_values[3] > 0 or pad_values[5] > 0:
              merged = merged[:, :height, :width, :].contiguous()
          merged = merged.view(batch_size, height * width, channels)

          # Post-norm residual around attention, then around the MLP.
          hidden_states = shortcut + self.drop_path(self.layernorm_before(merged))
          mlp_output = self.output(self.intermediate(hidden_states))
          layer_output = hidden_states + self.drop_path(self.layernorm_after(mlp_output))

          if output_attentions:
              return (layer_output, attention_outputs[1])
          return (layer_output,)
  class Swinv2Stage(GradientCheckpointingLayer):
      """A stack of ``depth`` Swinv2 blocks, optionally followed by a patch-merging layer.

      Even-indexed blocks use regular windows; odd-indexed blocks request a shift of ``config.window_size // 2``.
      """

      def __init__(
          self, config, dim, input_resolution, depth, num_heads, drop_path, downsample, pretrained_window_size=0
      ):
          super().__init__()
          self.config = config
          self.dim = dim
          self.blocks = nn.ModuleList(
              [
                  Swinv2Layer(
                      config=config,
                      dim=dim,
                      input_resolution=input_resolution,
                      num_heads=num_heads,
                      drop_path_rate=drop_path[block_idx],
                      shift_size=config.window_size // 2 if block_idx % 2 else 0,
                      pretrained_window_size=pretrained_window_size,
                  )
                  for block_idx in range(depth)
              ]
          )
          # Optional patch merging at the end of the stage.
          self.downsample = (
              None if downsample is None else downsample(input_resolution, dim=dim, norm_layer=nn.LayerNorm)
          )
          self.pointing = False

      def forward(
          self,
          hidden_states: torch.Tensor,
          input_dimensions: tuple[int, int],
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = False,
      ) -> tuple[torch.Tensor]:
          """Run every block, then merge patches if configured.

          Returns ``(hidden_states, hidden_states_before_downsampling, output_dimensions)`` where
          ``output_dimensions`` is ``(height, width, next_height, next_width)``, plus the last block's
          attentions when ``output_attentions`` is set.
          """
          height, width = input_dimensions
          for block_idx, block in enumerate(self.blocks):
              block_head_mask = None if head_mask is None else head_mask[block_idx]
              layer_outputs = block(hidden_states, input_dimensions, block_head_mask, output_attentions)
              hidden_states = layer_outputs[0]

          hidden_states_before_downsampling = hidden_states
          if self.downsample is None:
              output_dimensions = (height, width, height, width)
          else:
              # Patch merging halves each spatial dimension, rounding up.
              output_dimensions = (height, width, (height + 1) // 2, (width + 1) // 2)
              hidden_states = self.downsample(hidden_states_before_downsampling, input_dimensions)

          stage_outputs = (hidden_states, hidden_states_before_downsampling, output_dimensions)
          if output_attentions:
              # Only the attentions of the final block are forwarded.
              stage_outputs += layer_outputs[1:]
          return stage_outputs
  class Swinv2Encoder(nn.Module):
      """Stack of ``len(config.depths)`` Swinv2 stages; every stage except the last ends with patch merging."""

      def __init__(self, config, grid_size, pretrained_window_sizes=(0, 0, 0, 0)):
          super().__init__()
          self.num_layers = len(config.depths)
          self.config = config
          # When the config provides pretrained window sizes they take precedence over the argument.
          if self.config.pretrained_window_sizes is not None:
              pretrained_window_sizes = config.pretrained_window_sizes
          # Stochastic-depth rate rises linearly from 0 to drop_path_rate across all blocks of all stages.
          dpr = [rate.item() for rate in torch.linspace(0, config.drop_path_rate, sum(config.depths), device="cpu")]
          stages = []
          first_block = 0
          for stage_idx in range(self.num_layers):
              stage_depth = config.depths[stage_idx]
              scale = 2**stage_idx
              stages.append(
                  Swinv2Stage(
                      config=config,
                      dim=int(config.embed_dim * scale),
                      input_resolution=(grid_size[0] // scale, grid_size[1] // scale),
                      depth=stage_depth,
                      num_heads=config.num_heads[stage_idx],
                      drop_path=dpr[first_block : first_block + stage_depth],
                      downsample=Swinv2PatchMerging if stage_idx < self.num_layers - 1 else None,
                      pretrained_window_size=pretrained_window_sizes[stage_idx],
                  )
              )
              first_block += stage_depth
          self.layers = nn.ModuleList(stages)
          self.gradient_checkpointing = False

      def forward(
          self,
          hidden_states: torch.Tensor,
          input_dimensions: tuple[int, int],
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = False,
          output_hidden_states: Optional[bool] = False,
          output_hidden_states_before_downsampling: Optional[bool] = False,
          return_dict: Optional[bool] = True,
      ) -> Union[tuple, Swinv2EncoderOutput]:
          """Run all stages, optionally collecting per-stage hidden states and attentions.

          With ``output_hidden_states_before_downsampling`` the recorded states are each stage's output at its
          own resolution (before patch merging); otherwise they are the merged outputs fed to the next stage.
          """
          all_hidden_states = () if output_hidden_states else None
          all_reshaped_hidden_states = () if output_hidden_states else None
          all_self_attentions = () if output_attentions else None

          def to_spatial(tokens, dims):
              # (batch, height * width, channels) -> (batch, channels, height, width)
              batch_size, _, hidden_size = tokens.shape
              return tokens.view(batch_size, *dims, hidden_size).permute(0, 3, 1, 2)

          if output_hidden_states:
              all_hidden_states += (hidden_states,)
              all_reshaped_hidden_states += (to_spatial(hidden_states, input_dimensions),)

          for stage_idx, stage in enumerate(self.layers):
              stage_head_mask = None if head_mask is None else head_mask[stage_idx]
              layer_outputs = stage(hidden_states, input_dimensions, stage_head_mask, output_attentions)
              hidden_states, hidden_states_before_downsampling, output_dimensions = layer_outputs[:3]
              # The next stage consumes the post-merge resolution.
              input_dimensions = (output_dimensions[-2], output_dimensions[-1])

              if output_hidden_states:
                  if output_hidden_states_before_downsampling:
                      pre_merge_dims = (output_dimensions[0], output_dimensions[1])
                      all_hidden_states += (hidden_states_before_downsampling,)
                      all_reshaped_hidden_states += (to_spatial(hidden_states_before_downsampling, pre_merge_dims),)
                  else:
                      all_hidden_states += (hidden_states,)
                      all_reshaped_hidden_states += (to_spatial(hidden_states, input_dimensions),)

              if output_attentions:
                  all_self_attentions += layer_outputs[3:]

          if not return_dict:
              candidates = (hidden_states, all_hidden_states, all_self_attentions, all_reshaped_hidden_states)
              return tuple(value for value in candidates if value is not None)

          return Swinv2EncoderOutput(
              last_hidden_state=hidden_states,
              hidden_states=all_hidden_states,
              attentions=all_self_attentions,
              reshaped_hidden_states=all_reshaped_hidden_states,
          )
  @auto_docstring
  class Swinv2PreTrainedModel(PreTrainedModel):
      config: Swinv2Config
      base_model_prefix = "swinv2"
      main_input_name = "pixel_values"
      supports_gradient_checkpointing = True
      _no_split_modules = ["Swinv2Stage"]

      def _init_weights(self, module):
          """Initialize ``module``'s parameters in place according to its type."""
          if isinstance(module, (nn.Linear, nn.Conv2d)):
              # Plain normal init, unlike the truncated normal of the TF reference
              # (cf https://github.com/pytorch/pytorch/pull/5617).
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              if module.bias is not None:
                  module.bias.data.zero_()
              return
          if isinstance(module, nn.LayerNorm):
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)
              return
          if isinstance(module, Swinv2Embeddings):
              # Both the mask token and the absolute position embeddings are optional.
              for optional_param in (module.mask_token, module.position_embeddings):
                  if optional_param is not None:
                      optional_param.data.zero_()
              return
          if isinstance(module, Swinv2SelfAttention):
              # Reset the learnable attention logit scale to log(10).
              module.logit_scale.data.fill_(math.log(10))
  @auto_docstring
  # Adapted from transformers.models.swin.modeling_swin.SwinModel with SWIN->SWINV2,Swin->Swinv2
  # (diverges in `_prune_heads`, which rejects head pruning explicitly)
  class Swinv2Model(Swinv2PreTrainedModel):
      def __init__(self, config, add_pooling_layer=True, use_mask_token=False):
          r"""
          add_pooling_layer (`bool`, *optional*, defaults to `True`):
              Whether or not to apply pooling layer.
          use_mask_token (`bool`, *optional*, defaults to `False`):
              Whether or not to create and apply mask tokens in the embedding layer.
          """
          super().__init__(config)
          self.config = config
          self.num_layers = len(config.depths)
          # Width of the last stage: the encoder doubles embed_dim at every stage.
          self.num_features = int(config.embed_dim * 2 ** (self.num_layers - 1))
          self.embeddings = Swinv2Embeddings(config, use_mask_token=use_mask_token)
          self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)
          self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
          # Average pooling over the token dimension (see forward).
          self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings.patch_embeddings

      def _prune_heads(self, heads_to_prune):
          """
          Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
          class PreTrainedModel

          Swinv2 does not support head pruning: a request that removes no head is accepted as a no-op, anything else
          raises.

          Raises:
              NotImplementedError: if `heads_to_prune` contains at least one head to remove.
          """
          requested = {layer: heads for layer, heads in heads_to_prune.items() if len(heads) > 0}
          if not requested:
              return
          # There is no `self.encoder.layer`: stages live in `self.encoder.layers` and blocks in `stage.blocks`.
          # NOTE(review): even routed to the blocks, `Swinv2Attention.prune_heads` only shrinks the query/key/value and
          # output projections and leaves other attention parameters such as `logit_scale` untouched; those presumably
          # hold one entry per head, so pruning is rejected explicitly rather than leaving inconsistent modules.
          raise NotImplementedError(
              f"Head pruning is not supported by {self.__class__.__name__} (requested: {requested})."
          )

      @auto_docstring
      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
          bool_masked_pos: Optional[torch.BoolTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          interpolate_pos_encoding: bool = False,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, Swinv2ModelOutput]:
          r"""
          bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
              Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
          """
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          if pixel_values is None:
              raise ValueError("You have to specify pixel_values")

          # Prepare head mask if needed
          # 1.0 in head_mask indicate we keep the head
          # attention_probs has shape bsz x n_heads x N x N
          # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
          # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
          head_mask = self.get_head_mask(head_mask, len(self.config.depths))

          embedding_output, input_dimensions = self.embeddings(
              pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
          )

          encoder_outputs = self.encoder(
              embedding_output,
              input_dimensions,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          sequence_output = encoder_outputs[0]
          sequence_output = self.layernorm(sequence_output)

          pooled_output = None
          if self.pooler is not None:
              # (batch, tokens, channels) -> (batch, channels, tokens) so the pool runs over tokens.
              pooled_output = self.pooler(sequence_output.transpose(1, 2))
              pooled_output = torch.flatten(pooled_output, 1)

          if not return_dict:
              output = (sequence_output, pooled_output) + encoder_outputs[1:]
              return output

          return Swinv2ModelOutput(
              last_hidden_state=sequence_output,
              pooler_output=pooled_output,
              hidden_states=encoder_outputs.hidden_states,
              attentions=encoder_outputs.attentions,
              reshaped_hidden_states=encoder_outputs.reshaped_hidden_states,
          )
  @auto_docstring(
      custom_intro="""
      Swinv2 Model with a decoder on top for masked image modeling, as proposed in
      [SimMIM](https://huggingface.co/papers/2111.09886).
      <Tip>
      Note that we provide a script to pre-train this model on custom data in our [examples
      directory](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-pretraining).
      </Tip>
      """
  )
  # Adapted from transformers.models.swin.modeling_swin.SwinForMaskedImageModeling with swin->swinv2, base-simmim-window6-192->tiny-patch4-window8-256,SWIN->SWINV2,Swin->Swinv2,192->256
  # (diverges: the loss-mask patch grid is derived from the input resolution instead of `config.image_size`)
  class Swinv2ForMaskedImageModeling(Swinv2PreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          self.swinv2 = Swinv2Model(config, add_pooling_layer=False, use_mask_token=True)

          # Width of the last encoder stage, as exposed by Swinv2Model.
          num_features = self.swinv2.num_features
          # 1x1 conv to encoder_stride**2 * num_channels, then PixelShuffle back to pixel resolution.
          self.decoder = nn.Sequential(
              nn.Conv2d(
                  in_channels=num_features, out_channels=config.encoder_stride**2 * config.num_channels, kernel_size=1
              ),
              nn.PixelShuffle(config.encoder_stride),
          )

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
          bool_masked_pos: Optional[torch.BoolTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          interpolate_pos_encoding: bool = False,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, Swinv2MaskedImageModelingOutput]:
          r"""
          bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
              Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
          Examples:
          ```python
          >>> from transformers import AutoImageProcessor, Swinv2ForMaskedImageModeling
          >>> import torch
          >>> from PIL import Image
          >>> import requests
          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)
          >>> image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
          >>> model = Swinv2ForMaskedImageModeling.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
          >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2
          >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
          >>> # create random boolean mask of shape (batch_size, num_patches)
          >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
          >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
          >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction
          >>> list(reconstructed_pixel_values.shape)
          [1, 3, 256, 256]
          ```"""
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          outputs = self.swinv2(
              pixel_values,
              bool_masked_pos=bool_masked_pos,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              interpolate_pos_encoding=interpolate_pos_encoding,
              return_dict=return_dict,
          )

          sequence_output = outputs[0]
          # Reshape to (batch_size, num_channels, height, width)
          # NOTE(review): this assumes a square final token grid (height == width).
          sequence_output = sequence_output.transpose(1, 2)
          batch_size, num_channels, sequence_length = sequence_output.shape
          height = width = math.floor(sequence_length**0.5)
          sequence_output = sequence_output.reshape(batch_size, num_channels, height, width)

          # Reconstruct pixel values
          reconstructed_pixel_values = self.decoder(sequence_output)

          masked_im_loss = None
          if bool_masked_pos is not None:
              patch_size = self.config.patch_size
              # Use the patch grid of the actual input rather than `config.image_size`, so the loss also works for
              # inputs at another resolution (e.g. with `interpolate_pos_encoding=True`). For inputs of size
              # `config.image_size` x `config.image_size` this is the same grid as the config-derived one.
              grid_height = pixel_values.shape[-2] // patch_size
              grid_width = pixel_values.shape[-1] // patch_size
              bool_masked_pos = bool_masked_pos.reshape(-1, grid_height, grid_width)
              # Upsample the per-patch mask to per-pixel resolution: (batch, 1, height, width).
              mask = (
                  bool_masked_pos.repeat_interleave(patch_size, 1)
                  .repeat_interleave(patch_size, 2)
                  .unsqueeze(1)
                  .contiguous()
              )
              reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
              # Mean L1 over masked pixels and channels; the epsilon avoids dividing by zero when nothing is masked.
              masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels

          if not return_dict:
              output = (reconstructed_pixel_values,) + outputs[2:]
              return ((masked_im_loss,) + output) if masked_im_loss is not None else output

          return Swinv2MaskedImageModelingOutput(
              loss=masked_im_loss,
              reconstruction=reconstructed_pixel_values,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              reshaped_hidden_states=outputs.reshaped_hidden_states,
          )
  @auto_docstring(
      custom_intro="""
      Swinv2 Model transformer with an image classification head on top (a linear layer on top of the average-pooled
      final hidden states, i.e. the model's `pooler_output`) e.g. for ImageNet.
      <Tip>
      Note that it's possible to fine-tune SwinV2 on higher resolution images than the ones it has been trained on, by
      setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
      position embeddings to the higher resolution.
      </Tip>
      """
  )
  # Copied from transformers.models.swin.modeling_swin.SwinForImageClassification with SWIN->SWINV2,Swin->Swinv2,swin->swinv2
  class Swinv2ForImageClassification(Swinv2PreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          self.num_labels = config.num_labels
          self.swinv2 = Swinv2Model(config)

          # Classifier head (an identity when the config defines no labels)
          self.classifier = (
              nn.Linear(self.swinv2.num_features, config.num_labels) if config.num_labels > 0 else nn.Identity()
          )

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          pixel_values: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          interpolate_pos_encoding: bool = False,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, Swinv2ImageClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          outputs = self.swinv2(
              pixel_values,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              interpolate_pos_encoding=interpolate_pos_encoding,
              return_dict=return_dict,
          )

          # Index 1 is the pooled output (token average), available because the backbone keeps its pooling layer.
          pooled_output = outputs[1]

          logits = self.classifier(pooled_output)

          loss = None
          if labels is not None:
              loss = self.loss_function(labels, logits, self.config)

          if not return_dict:
              output = (logits,) + outputs[2:]
              return ((loss,) + output) if loss is not None else output

          return Swinv2ImageClassifierOutput(
              loss=loss,
              logits=logits,
              hidden_states=outputs.hidden_states,
              attentions=outputs.attentions,
              reshaped_hidden_states=outputs.reshaped_hidden_states,
          )
  @auto_docstring(
      custom_intro="""
      Swinv2 backbone, to be used with frameworks like DETR and MaskFormer.
      """
  )
  class Swinv2Backbone(Swinv2PreTrainedModel, BackboneMixin):
      def __init__(self, config):
          super().__init__(config)
          super()._init_backbone(config)

          # Channels of each exposed feature map: the stem (embeddings) followed by one entry per stage,
          # with the width doubling at every stage.
          self.num_features = [config.embed_dim] + [int(config.embed_dim * 2**i) for i in range(len(config.depths))]
          self.embeddings = Swinv2Embeddings(config)
          self.encoder = Swinv2Encoder(config, self.embeddings.patch_grid)

          # initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings.patch_embeddings

      @auto_docstring
      def forward(
          self,
          pixel_values: Tensor,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> BackboneOutput:
          r"""
          Examples:
          ```python
          >>> from transformers import AutoImageProcessor, AutoBackbone
          >>> import torch
          >>> from PIL import Image
          >>> import requests
          >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
          >>> image = Image.open(requests.get(url, stream=True).raw)
          >>> processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window8-256")
          >>> model = AutoBackbone.from_pretrained(
          ...     "microsoft/swinv2-tiny-patch4-window8-256", out_features=["stage1", "stage2", "stage3", "stage4"]
          ... )
          >>> inputs = processor(image, return_tensors="pt")
          >>> outputs = model(**inputs)
          >>> feature_maps = outputs.feature_maps
          >>> list(feature_maps[-1].shape)
          [1, 768, 8, 8]
          ```"""
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

          embedding_output, input_dimensions = self.embeddings(pixel_values)

          # Hidden states are always collected (before patch merging) because the feature maps are built from them.
          outputs = self.encoder(
              embedding_output,
              input_dimensions,
              head_mask=None,
              output_attentions=output_attentions,
              output_hidden_states=True,
              output_hidden_states_before_downsampling=True,
              return_dict=return_dict,
          )

          # (batch, channels, height, width) maps; index 0 is the embedding output, then one per stage.
          hidden_states = outputs.reshaped_hidden_states if return_dict else outputs[-1]

          feature_maps = ()
          for stage, hidden_state in zip(self.stage_names, hidden_states):
              if stage in self.out_features:
                  feature_maps += (hidden_state,)

          if not return_dict:
              output = (feature_maps,)
              if output_hidden_states:
                  output += (outputs[1],)
              if output_attentions:
                  output += (outputs[2],)
              return output

          return BackboneOutput(
              feature_maps=feature_maps,
              hidden_states=outputs.hidden_states if output_hidden_states else None,
              attentions=outputs.attentions,
          )
  # Public classes exported by this module.
  __all__ = [
      "Swinv2ForImageClassification", "Swinv2ForMaskedImageModeling", "Swinv2Model",
      "Swinv2PreTrainedModel", "Swinv2Backbone",
  ]