swin_transformer.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215
  1. """ Swin Transformer
  2. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`
  3. - https://arxiv.org/pdf/2103.14030
  4. Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
  5. S3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) Swin weights from
  6. - https://github.com/microsoft/Cream/tree/main/AutoFormerV2
  7. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  8. """
  9. # --------------------------------------------------------
  10. # Swin Transformer
  11. # Copyright (c) 2021 Microsoft
  12. # Licensed under The MIT License [see LICENSE for details]
  13. # Written by Ze Liu
  14. # --------------------------------------------------------
  15. import logging
  16. import math
  17. from typing import Any, Dict, Callable, List, Optional, Set, Tuple, Union, Type
  18. import torch
  19. import torch.nn as nn
  20. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  21. from timm.layers import PatchEmbed, Mlp, DropPath, calculate_drop_path_rates, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \
  22. use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid
  23. from ._builder import build_model_with_cfg
  24. from ._features import feature_take_indices
  25. from ._features_fx import register_notrace_function
  26. from ._manipulate import checkpoint_seq, named_apply
  27. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
  28. from .vision_transformer import get_init_weights_vit
  29. __all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this
  30. _logger = logging.getLogger(__name__)
  31. _int_or_tuple_2_t = Union[int, Tuple[int, int]]
  32. def window_partition(
  33. x: torch.Tensor,
  34. window_size: Tuple[int, int],
  35. ) -> torch.Tensor:
  36. """Partition into non-overlapping windows.
  37. Args:
  38. x: Input tokens with shape [B, H, W, C].
  39. window_size: Window size.
  40. Returns:
  41. Windows after partition with shape [B * num_windows, window_size, window_size, C].
  42. """
  43. B, H, W, C = x.shape
  44. x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
  45. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
  46. return windows
  47. @register_notrace_function # reason: int argument is a Proxy
  48. def window_reverse(windows: torch.Tensor, window_size: Tuple[int, int], H: int, W: int) -> torch.Tensor:
  49. """Reverse window partition.
  50. Args:
  51. windows: Windows with shape (num_windows*B, window_size, window_size, C).
  52. window_size: Window size.
  53. H: Height of image.
  54. W: Width of image.
  55. Returns:
  56. Tensor with shape (B, H, W, C).
  57. """
  58. C = windows.shape[-1]
  59. x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
  60. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
  61. return x
  62. def get_relative_position_index(win_h: int, win_w: int, device=None) -> torch.Tensor:
  63. """Get pair-wise relative position index for each token inside the window.
  64. Args:
  65. win_h: Window height.
  66. win_w: Window width.
  67. Returns:
  68. Relative position index tensor.
  69. """
  70. # get pair-wise relative position index for each token inside the window
  71. coords = torch.stack(ndgrid(
  72. torch.arange(win_h, device=device, dtype=torch.long),
  73. torch.arange(win_w, device=device, dtype=torch.long),
  74. )) # 2, Wh, Ww
  75. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  76. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  77. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  78. relative_coords[:, :, 0] += win_h - 1 # shift to start from 0
  79. relative_coords[:, :, 1] += win_w - 1
  80. relative_coords[:, :, 0] *= 2 * win_w - 1
  81. return relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  82. class WindowAttention(nn.Module):
  83. """Window based multi-head self attention (W-MSA) module with relative position bias.
  84. Supports both shifted and non-shifted windows.
  85. """
  86. fused_attn: torch.jit.Final[bool]
  87. def __init__(
  88. self,
  89. dim: int,
  90. num_heads: int,
  91. head_dim: Optional[int] = None,
  92. window_size: _int_or_tuple_2_t = 7,
  93. qkv_bias: bool = True,
  94. attn_drop: float = 0.,
  95. proj_drop: float = 0.,
  96. device=None,
  97. dtype=None,
  98. ):
  99. """
  100. Args:
  101. dim: Number of input channels.
  102. num_heads: Number of attention heads.
  103. head_dim: Number of channels per head (dim // num_heads if not set)
  104. window_size: The height and width of the window.
  105. qkv_bias: If True, add a learnable bias to query, key, value.
  106. attn_drop: Dropout ratio of attention weight.
  107. proj_drop: Dropout ratio of output.
  108. """
  109. dd = {'device': device, 'dtype': dtype}
  110. super().__init__()
  111. self.dim = dim
  112. self.window_size = to_2tuple(window_size) # Wh, Ww
  113. win_h, win_w = self.window_size
  114. self.window_area = win_h * win_w
  115. self.num_heads = num_heads
  116. head_dim = head_dim or dim // num_heads
  117. attn_dim = head_dim * num_heads
  118. self.scale = head_dim ** -0.5
  119. self.fused_attn = use_fused_attn(experimental=True) # NOTE not tested for prime-time yet
  120. # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
  121. self.relative_position_bias_table = nn.Parameter(
  122. torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads, **dd))
  123. # get pair-wise relative position index for each token inside the window
  124. self.register_buffer(
  125. "relative_position_index",
  126. get_relative_position_index(win_h, win_w, device=device),
  127. persistent=False,
  128. )
  129. self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias, **dd)
  130. self.attn_drop = nn.Dropout(attn_drop)
  131. self.proj = nn.Linear(attn_dim, dim, **dd)
  132. self.proj_drop = nn.Dropout(proj_drop)
  133. trunc_normal_(self.relative_position_bias_table, std=.02)
  134. self.softmax = nn.Softmax(dim=-1)
  135. def set_window_size(self, window_size: Tuple[int, int]) -> None:
  136. """Update window size & interpolate position embeddings
  137. Args:
  138. window_size (int): New window size
  139. """
  140. window_size = to_2tuple(window_size)
  141. if window_size == self.window_size:
  142. return
  143. self.window_size = window_size
  144. win_h, win_w = self.window_size
  145. self.window_area = win_h * win_w
  146. with torch.no_grad():
  147. new_bias_shape = (2 * win_h - 1) * (2 * win_w - 1), self.num_heads
  148. self.relative_position_bias_table = nn.Parameter(
  149. resize_rel_pos_bias_table(
  150. self.relative_position_bias_table,
  151. new_window_size=self.window_size,
  152. new_bias_shape=new_bias_shape,
  153. ))
  154. self.register_buffer(
  155. "relative_position_index",
  156. get_relative_position_index(win_h, win_w, device=self.relative_position_bias_table.device),
  157. persistent=False,
  158. )
  159. def _get_rel_pos_bias(self) -> torch.Tensor:
  160. relative_position_bias = self.relative_position_bias_table[
  161. self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH
  162. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  163. return relative_position_bias.unsqueeze(0)
  164. def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  165. """Forward pass.
  166. Args:
  167. x: Input features with shape of (num_windows*B, N, C).
  168. mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None.
  169. Returns:
  170. Output features with shape of (num_windows*B, N, C).
  171. """
  172. B_, N, C = x.shape
  173. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  174. q, k, v = qkv.unbind(0)
  175. if self.fused_attn:
  176. attn_mask = self._get_rel_pos_bias()
  177. if mask is not None:
  178. num_win = mask.shape[0]
  179. mask = mask.view(1, num_win, 1, N, N).expand(B_ // num_win, -1, self.num_heads, -1, -1)
  180. attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N)
  181. x = torch.nn.functional.scaled_dot_product_attention(
  182. q, k, v,
  183. attn_mask=attn_mask,
  184. dropout_p=self.attn_drop.p if self.training else 0.,
  185. )
  186. else:
  187. q = q * self.scale
  188. attn = q @ k.transpose(-2, -1)
  189. attn = attn + self._get_rel_pos_bias()
  190. if mask is not None:
  191. num_win = mask.shape[0]
  192. attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
  193. attn = attn.view(-1, self.num_heads, N, N)
  194. attn = self.softmax(attn)
  195. attn = self.attn_drop(attn)
  196. x = attn @ v
  197. x = x.transpose(1, 2).reshape(B_, N, -1)
  198. x = self.proj(x)
  199. x = self.proj_drop(x)
  200. return x
  201. class SwinTransformerBlock(nn.Module):
  202. """Swin Transformer Block.
  203. A transformer block with window-based self-attention and shifted windows.
  204. """
  205. def __init__(
  206. self,
  207. dim: int,
  208. input_resolution: _int_or_tuple_2_t,
  209. num_heads: int = 4,
  210. head_dim: Optional[int] = None,
  211. window_size: _int_or_tuple_2_t = 7,
  212. shift_size: int = 0,
  213. always_partition: bool = False,
  214. dynamic_mask: bool = False,
  215. mlp_ratio: float = 4.,
  216. qkv_bias: bool = True,
  217. proj_drop: float = 0.,
  218. attn_drop: float = 0.,
  219. drop_path: float = 0.,
  220. act_layer: Type[nn.Module] = nn.GELU,
  221. norm_layer: Type[nn.Module] = nn.LayerNorm,
  222. device=None,
  223. dtype=None,
  224. ):
  225. """
  226. Args:
  227. dim: Number of input channels.
  228. input_resolution: Input resolution.
  229. window_size: Window size.
  230. num_heads: Number of attention heads.
  231. head_dim: Enforce the number of channels per head
  232. shift_size: Shift size for SW-MSA.
  233. always_partition: Always partition into full windows and shift
  234. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  235. qkv_bias: If True, add a learnable bias to query, key, value.
  236. proj_drop: Dropout rate.
  237. attn_drop: Attention dropout rate.
  238. drop_path: Stochastic depth rate.
  239. act_layer: Activation layer.
  240. norm_layer: Normalization layer.
  241. """
  242. dd = {'device': device, 'dtype': dtype}
  243. super().__init__()
  244. self.dim = dim
  245. self.input_resolution = input_resolution
  246. self.target_shift_size = to_2tuple(shift_size) # store for later resize
  247. self.always_partition = always_partition
  248. self.dynamic_mask = dynamic_mask
  249. self.window_size, self.shift_size = self._calc_window_shift(window_size, shift_size)
  250. self.window_area = self.window_size[0] * self.window_size[1]
  251. self.mlp_ratio = mlp_ratio
  252. self.norm1 = norm_layer(dim, **dd)
  253. self.attn = WindowAttention(
  254. dim,
  255. num_heads=num_heads,
  256. head_dim=head_dim,
  257. window_size=self.window_size,
  258. qkv_bias=qkv_bias,
  259. attn_drop=attn_drop,
  260. proj_drop=proj_drop,
  261. **dd,
  262. )
  263. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  264. self.norm2 = norm_layer(dim, **dd)
  265. self.mlp = Mlp(
  266. in_features=dim,
  267. hidden_features=int(dim * mlp_ratio),
  268. act_layer=act_layer,
  269. drop=proj_drop,
  270. **dd,
  271. )
  272. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  273. self.register_buffer(
  274. "attn_mask",
  275. None if self.dynamic_mask else self.get_attn_mask(**dd),
  276. persistent=False,
  277. )
  278. def get_attn_mask(
  279. self,
  280. x: Optional[torch.Tensor] = None,
  281. device: Optional[torch.device] = None,
  282. dtype: Optional[torch.dtype] = None,
  283. ) -> Optional[torch.Tensor]:
  284. if any(self.shift_size):
  285. # calculate attention mask for SW-MSA
  286. if x is not None:
  287. H, W = x.shape[1], x.shape[2]
  288. device = x.device
  289. dtype = x.dtype
  290. else:
  291. H, W = self.input_resolution
  292. device = device
  293. dtype = dtype
  294. H = math.ceil(H / self.window_size[0]) * self.window_size[0]
  295. W = math.ceil(W / self.window_size[1]) * self.window_size[1]
  296. img_mask = torch.zeros((1, H, W, 1), dtype=dtype, device=device) # 1 H W 1
  297. cnt = 0
  298. for h in (
  299. (0, -self.window_size[0]),
  300. (-self.window_size[0], -self.shift_size[0]),
  301. (-self.shift_size[0], None),
  302. ):
  303. for w in (
  304. (0, -self.window_size[1]),
  305. (-self.window_size[1], -self.shift_size[1]),
  306. (-self.shift_size[1], None),
  307. ):
  308. img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
  309. cnt += 1
  310. mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
  311. mask_windows = mask_windows.view(-1, self.window_area)
  312. attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
  313. attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
  314. else:
  315. attn_mask = None
  316. return attn_mask
  317. def _calc_window_shift(
  318. self,
  319. target_window_size: Union[int, Tuple[int, int]],
  320. target_shift_size: Optional[Union[int, Tuple[int, int]]] = None,
  321. ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
  322. target_window_size = to_2tuple(target_window_size)
  323. if target_shift_size is None:
  324. # if passed value is None, recalculate from default window_size // 2 if it was previously non-zero
  325. target_shift_size = self.target_shift_size
  326. if any(target_shift_size):
  327. target_shift_size = (target_window_size[0] // 2, target_window_size[1] // 2)
  328. else:
  329. target_shift_size = to_2tuple(target_shift_size)
  330. if self.always_partition:
  331. return target_window_size, target_shift_size
  332. window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
  333. shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
  334. return tuple(window_size), tuple(shift_size)
  335. def set_input_size(
  336. self,
  337. feat_size: Tuple[int, int],
  338. window_size: Tuple[int, int],
  339. always_partition: Optional[bool] = None,
  340. ):
  341. """
  342. Args:
  343. feat_size: New input resolution
  344. window_size: New window size
  345. always_partition: Change always_partition attribute if not None
  346. """
  347. self.input_resolution = feat_size
  348. if always_partition is not None:
  349. self.always_partition = always_partition
  350. self.window_size, self.shift_size = self._calc_window_shift(window_size)
  351. self.window_area = self.window_size[0] * self.window_size[1]
  352. self.attn.set_window_size(self.window_size)
  353. device = self.attn_mask.device if self.attn_mask is not None else None
  354. dtype = self.attn_mask.dtype if self.attn_mask is not None else None
  355. self.register_buffer(
  356. "attn_mask",
  357. None if self.dynamic_mask else self.get_attn_mask(device=device, dtype=dtype),
  358. persistent=False,
  359. )
  360. def _attn(self, x):
  361. B, H, W, C = x.shape
  362. # cyclic shift
  363. has_shift = any(self.shift_size)
  364. if has_shift:
  365. shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
  366. else:
  367. shifted_x = x
  368. # pad for resolution not divisible by window size
  369. pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
  370. pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
  371. shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h))
  372. _, Hp, Wp, _ = shifted_x.shape
  373. # partition windows
  374. x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
  375. x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C
  376. # W-MSA/SW-MSA
  377. if getattr(self, 'dynamic_mask', False):
  378. attn_mask = self.get_attn_mask(shifted_x)
  379. else:
  380. attn_mask = self.attn_mask
  381. attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
  382. # merge windows
  383. attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
  384. shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
  385. shifted_x = shifted_x[:, :H, :W, :].contiguous()
  386. # reverse cyclic shift
  387. if has_shift:
  388. x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
  389. else:
  390. x = shifted_x
  391. return x
  392. def forward(self, x: torch.Tensor) -> torch.Tensor:
  393. """Forward pass.
  394. Args:
  395. x: Input features with shape (B, H, W, C).
  396. Returns:
  397. Output features with shape (B, H, W, C).
  398. """
  399. B, H, W, C = x.shape
  400. x = x + self.drop_path1(self._attn(self.norm1(x)))
  401. x = x.reshape(B, -1, C)
  402. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  403. x = x.reshape(B, H, W, C)
  404. return x
  405. class PatchMerging(nn.Module):
  406. """Patch Merging Layer.
  407. Downsample features by merging 2x2 neighboring patches.
  408. """
  409. def __init__(
  410. self,
  411. dim: int,
  412. out_dim: Optional[int] = None,
  413. norm_layer: Type[nn.Module] = nn.LayerNorm,
  414. device=None,
  415. dtype=None,
  416. ):
  417. """
  418. Args:
  419. dim: Number of input channels.
  420. out_dim: Number of output channels (or 2 * dim if None)
  421. norm_layer: Normalization layer.
  422. """
  423. dd = {'device': device, 'dtype': dtype}
  424. super().__init__()
  425. self.dim = dim
  426. self.out_dim = out_dim or 2 * dim
  427. self.norm = norm_layer(4 * dim, **dd)
  428. self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False, **dd)
  429. def forward(self, x: torch.Tensor) -> torch.Tensor:
  430. """Forward pass.
  431. Args:
  432. x: Input features with shape (B, H, W, C).
  433. Returns:
  434. Output features with shape (B, H//2, W//2, out_dim).
  435. """
  436. B, H, W, C = x.shape
  437. pad_values = (0, 0, 0, W % 2, 0, H % 2)
  438. x = nn.functional.pad(x, pad_values)
  439. _, H, W, _ = x.shape
  440. x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
  441. x = self.norm(x)
  442. x = self.reduction(x)
  443. return x
  444. class SwinTransformerStage(nn.Module):
  445. """A basic Swin Transformer layer for one stage.
  446. Contains multiple Swin Transformer blocks and optional downsampling.
  447. """
  448. def __init__(
  449. self,
  450. dim: int,
  451. out_dim: int,
  452. input_resolution: Tuple[int, int],
  453. depth: int,
  454. downsample: bool = True,
  455. num_heads: int = 4,
  456. head_dim: Optional[int] = None,
  457. window_size: _int_or_tuple_2_t = 7,
  458. always_partition: bool = False,
  459. dynamic_mask: bool = False,
  460. mlp_ratio: float = 4.,
  461. qkv_bias: bool = True,
  462. proj_drop: float = 0.,
  463. attn_drop: float = 0.,
  464. drop_path: Union[List[float], float] = 0.,
  465. norm_layer: Type[nn.Module] = nn.LayerNorm,
  466. device=None,
  467. dtype=None,
  468. ):
  469. """
  470. Args:
  471. dim: Number of input channels.
  472. out_dim: Number of output channels.
  473. input_resolution: Input resolution.
  474. depth: Number of blocks.
  475. downsample: Downsample layer at the end of the layer.
  476. num_heads: Number of attention heads.
  477. head_dim: Channels per head (dim // num_heads if not set)
  478. window_size: Local window size.
  479. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  480. qkv_bias: If True, add a learnable bias to query, key, value.
  481. proj_drop: Projection dropout rate.
  482. attn_drop: Attention dropout rate.
  483. drop_path: Stochastic depth rate.
  484. norm_layer: Normalization layer.
  485. """
  486. dd = {'device': device, 'dtype': dtype}
  487. super().__init__()
  488. self.dim = dim
  489. self.input_resolution = input_resolution
  490. self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution
  491. self.depth = depth
  492. self.grad_checkpointing = False
  493. window_size = to_2tuple(window_size)
  494. shift_size = tuple([w // 2 for w in window_size])
  495. # patch merging layer
  496. if downsample:
  497. self.downsample = PatchMerging(
  498. dim=dim,
  499. out_dim=out_dim,
  500. norm_layer=norm_layer,
  501. **dd,
  502. )
  503. else:
  504. assert dim == out_dim
  505. self.downsample = nn.Identity()
  506. # build blocks
  507. self.blocks = nn.Sequential(*[
  508. SwinTransformerBlock(
  509. dim=out_dim,
  510. input_resolution=self.output_resolution,
  511. num_heads=num_heads,
  512. head_dim=head_dim,
  513. window_size=window_size,
  514. shift_size=0 if (i % 2 == 0) else shift_size,
  515. always_partition=always_partition,
  516. dynamic_mask=dynamic_mask,
  517. mlp_ratio=mlp_ratio,
  518. qkv_bias=qkv_bias,
  519. proj_drop=proj_drop,
  520. attn_drop=attn_drop,
  521. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  522. norm_layer=norm_layer,
  523. **dd,
  524. )
  525. for i in range(depth)])
  526. def set_input_size(
  527. self,
  528. feat_size: Tuple[int, int],
  529. window_size: int,
  530. always_partition: Optional[bool] = None,
  531. ):
  532. """ Updates the resolution, window size and so the pair-wise relative positions.
  533. Args:
  534. feat_size: New input (feature) resolution
  535. window_size: New window size
  536. always_partition: Always partition / shift the window
  537. """
  538. self.input_resolution = feat_size
  539. if isinstance(self.downsample, nn.Identity):
  540. self.output_resolution = feat_size
  541. else:
  542. self.output_resolution = tuple(i // 2 for i in feat_size)
  543. for block in self.blocks:
  544. block.set_input_size(
  545. feat_size=self.output_resolution,
  546. window_size=window_size,
  547. always_partition=always_partition,
  548. )
  549. def forward(self, x: torch.Tensor) -> torch.Tensor:
  550. """Forward pass.
  551. Args:
  552. x: Input features.
  553. Returns:
  554. Output features.
  555. """
  556. x = self.downsample(x)
  557. if self.grad_checkpointing and not torch.jit.is_scripting():
  558. x = checkpoint_seq(self.blocks, x)
  559. else:
  560. x = self.blocks(x)
  561. return x
  562. class SwinTransformer(nn.Module):
  563. """Swin Transformer.
  564. A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
  565. https://arxiv.org/pdf/2103.14030
  566. """
  567. def __init__(
  568. self,
  569. img_size: _int_or_tuple_2_t = 224,
  570. patch_size: int = 4,
  571. in_chans: int = 3,
  572. num_classes: int = 1000,
  573. global_pool: str = 'avg',
  574. embed_dim: int = 96,
  575. depths: Tuple[int, ...] = (2, 2, 6, 2),
  576. num_heads: Tuple[int, ...] = (3, 6, 12, 24),
  577. head_dim: Optional[int] = None,
  578. window_size: _int_or_tuple_2_t = 7,
  579. always_partition: bool = False,
  580. strict_img_size: bool = True,
  581. mlp_ratio: float = 4.,
  582. qkv_bias: bool = True,
  583. drop_rate: float = 0.,
  584. proj_drop_rate: float = 0.,
  585. attn_drop_rate: float = 0.,
  586. drop_path_rate: float = 0.1,
  587. embed_layer: Type[nn.Module] = PatchEmbed,
  588. norm_layer: Union[str, Type[nn.Module]] = nn.LayerNorm,
  589. weight_init: str = '',
  590. device=None,
  591. dtype=None,
  592. **kwargs,
  593. ):
  594. """
  595. Args:
  596. img_size: Input image size.
  597. patch_size: Patch size.
  598. in_chans: Number of input image channels.
  599. num_classes: Number of classes for classification head.
  600. embed_dim: Patch embedding dimension.
  601. depths: Depth of each Swin Transformer layer.
  602. num_heads: Number of attention heads in different layers.
  603. head_dim: Dimension of self-attention heads.
  604. window_size: Window size.
  605. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  606. qkv_bias: If True, add a learnable bias to query, key, value.
  607. drop_rate: Dropout rate.
  608. attn_drop_rate (float): Attention dropout rate.
  609. drop_path_rate (float): Stochastic depth rate.
  610. embed_layer: Patch embedding layer.
  611. norm_layer (nn.Module): Normalization layer.
  612. """
  613. super().__init__()
  614. dd = {'device': device, 'dtype': dtype}
  615. assert global_pool in ('', 'avg')
  616. self.num_classes = num_classes
  617. self.global_pool = global_pool
  618. self.output_fmt = 'NHWC'
  619. self.num_layers = len(depths)
  620. self.embed_dim = embed_dim
  621. self.num_features = self.head_hidden_size = int(embed_dim * 2 ** (self.num_layers - 1))
  622. self.feature_info = []
  623. if not isinstance(embed_dim, (tuple, list)):
  624. embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
  625. # split image into non-overlapping patches
  626. self.patch_embed = embed_layer(
  627. img_size=img_size,
  628. patch_size=patch_size,
  629. in_chans=in_chans,
  630. embed_dim=embed_dim[0],
  631. norm_layer=norm_layer,
  632. strict_img_size=strict_img_size,
  633. output_fmt='NHWC',
  634. **dd,
  635. )
  636. patch_grid = self.patch_embed.grid_size
  637. # build layers
  638. head_dim = to_ntuple(self.num_layers)(head_dim)
  639. if not isinstance(window_size, (list, tuple)):
  640. window_size = to_ntuple(self.num_layers)(window_size)
  641. elif len(window_size) == 2:
  642. window_size = (window_size,) * self.num_layers
  643. assert len(window_size) == self.num_layers
  644. mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
  645. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  646. layers = []
  647. in_dim = embed_dim[0]
  648. scale = 1
  649. for i in range(self.num_layers):
  650. out_dim = embed_dim[i]
  651. layers += [SwinTransformerStage(
  652. dim=in_dim,
  653. out_dim=out_dim,
  654. input_resolution=(
  655. patch_grid[0] // scale,
  656. patch_grid[1] // scale
  657. ),
  658. depth=depths[i],
  659. downsample=i > 0,
  660. num_heads=num_heads[i],
  661. head_dim=head_dim[i],
  662. window_size=window_size[i],
  663. always_partition=always_partition,
  664. dynamic_mask=not strict_img_size,
  665. mlp_ratio=mlp_ratio[i],
  666. qkv_bias=qkv_bias,
  667. proj_drop=proj_drop_rate,
  668. attn_drop=attn_drop_rate,
  669. drop_path=dpr[i],
  670. norm_layer=norm_layer,
  671. **dd,
  672. )]
  673. in_dim = out_dim
  674. if i > 0:
  675. scale *= 2
  676. self.feature_info += [dict(num_chs=out_dim, reduction=patch_size * scale, module=f'layers.{i}')]
  677. self.layers = nn.Sequential(*layers)
  678. self.norm = norm_layer(self.num_features, **dd)
  679. self.head = ClassifierHead(
  680. self.num_features,
  681. num_classes,
  682. pool_type=global_pool,
  683. drop_rate=drop_rate,
  684. input_fmt=self.output_fmt,
  685. **dd,
  686. )
  687. if weight_init != 'skip':
  688. self.init_weights(weight_init)
  689. @torch.jit.ignore
  690. def init_weights(self, mode: str = '') -> None:
  691. """Initialize model weights.
  692. Args:
  693. mode: Weight initialization mode ('jax', 'jax_nlhb', 'moco', or '').
  694. """
  695. assert mode in ('jax', 'jax_nlhb', 'moco', '')
  696. head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
  697. named_apply(get_init_weights_vit(mode, head_bias=head_bias), self)
  698. @torch.jit.ignore
  699. def no_weight_decay(self) -> Set[str]:
  700. """Parameters that should not use weight decay."""
  701. nwd = set()
  702. for n, _ in self.named_parameters():
  703. if 'relative_position_bias_table' in n:
  704. nwd.add(n)
  705. return nwd
  706. def set_input_size(
  707. self,
  708. img_size: Optional[Tuple[int, int]] = None,
  709. patch_size: Optional[Tuple[int, int]] = None,
  710. window_size: Optional[Tuple[int, int]] = None,
  711. window_ratio: int = 8,
  712. always_partition: Optional[bool] = None,
  713. ) -> None:
  714. """Update the image resolution and window size.
  715. Args:
  716. img_size: New input resolution, if None current resolution is used.
  717. patch_size: New patch size, if None use current patch size.
  718. window_size: New window size, if None based on new_img_size // window_div.
  719. window_ratio: Divisor for calculating window size from grid size.
  720. always_partition: Always partition into windows and shift (even if window size < feat size).
  721. """
  722. if img_size is not None or patch_size is not None:
  723. self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
  724. patch_grid = self.patch_embed.grid_size
  725. if window_size is None:
  726. window_size = tuple([pg // window_ratio for pg in patch_grid])
  727. for index, stage in enumerate(self.layers):
  728. stage_scale = 2 ** max(index - 1, 0)
  729. stage.set_input_size(
  730. feat_size=(patch_grid[0] // stage_scale, patch_grid[1] // stage_scale),
  731. window_size=window_size,
  732. always_partition=always_partition,
  733. )
  734. @torch.jit.ignore
  735. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  736. """Group parameters for optimization."""
  737. return dict(
  738. stem=r'^patch_embed', # stem and embed
  739. blocks=r'^layers\.(\d+)' if coarse else [
  740. (r'^layers\.(\d+).downsample', (0,)),
  741. (r'^layers\.(\d+)\.\w+\.(\d+)', None),
  742. (r'^norm', (99999,)),
  743. ]
  744. )
  745. @torch.jit.ignore
  746. def set_grad_checkpointing(self, enable: bool = True) -> None:
  747. """Enable or disable gradient checkpointing."""
  748. for l in self.layers:
  749. l.grad_checkpointing = enable
  750. @torch.jit.ignore
  751. def get_classifier(self) -> nn.Module:
  752. """Get the classifier head."""
  753. return self.head.fc
  754. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  755. """Reset the classifier head.
  756. Args:
  757. num_classes: Number of classes for new classifier.
  758. global_pool: Global pooling type.
  759. """
  760. self.num_classes = num_classes
  761. self.head.reset(num_classes, pool_type=global_pool)
  762. def forward_intermediates(
  763. self,
  764. x: torch.Tensor,
  765. indices: Optional[Union[int, List[int]]] = None,
  766. norm: bool = False,
  767. stop_early: bool = False,
  768. output_fmt: str = 'NCHW',
  769. intermediates_only: bool = False,
  770. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  771. """Forward features that returns intermediates.
  772. Args:
  773. x: Input image tensor.
  774. indices: Take last n blocks if int, all if None, select matching indices if sequence.
  775. norm: Apply norm layer to compatible intermediates.
  776. stop_early: Stop iterating over blocks when last desired intermediate hit.
  777. output_fmt: Shape of intermediate feature outputs.
  778. intermediates_only: Only return intermediate features.
  779. Returns:
  780. List of intermediate features or tuple of (final features, intermediates).
  781. """
  782. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  783. intermediates = []
  784. take_indices, max_index = feature_take_indices(len(self.layers), indices)
  785. # forward pass
  786. x = self.patch_embed(x)
  787. num_stages = len(self.layers)
  788. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  789. stages = self.layers
  790. else:
  791. stages = self.layers[:max_index + 1]
  792. for i, stage in enumerate(stages):
  793. x = stage(x)
  794. if i in take_indices:
  795. if norm and i == num_stages - 1:
  796. x_inter = self.norm(x) # applying final norm last intermediate
  797. else:
  798. x_inter = x
  799. x_inter = x_inter.permute(0, 3, 1, 2).contiguous()
  800. intermediates.append(x_inter)
  801. if intermediates_only:
  802. return intermediates
  803. x = self.norm(x)
  804. return x, intermediates
  805. def prune_intermediate_layers(
  806. self,
  807. indices: Union[int, List[int]] = 1,
  808. prune_norm: bool = False,
  809. prune_head: bool = True,
  810. ) -> List[int]:
  811. """Prune layers not required for specified intermediates.
  812. Args:
  813. indices: Indices of intermediate layers to keep.
  814. prune_norm: Whether to prune normalization layer.
  815. prune_head: Whether to prune the classifier head.
  816. Returns:
  817. List of indices that were kept.
  818. """
  819. take_indices, max_index = feature_take_indices(len(self.layers), indices)
  820. self.layers = self.layers[:max_index + 1] # truncate blocks
  821. if prune_norm:
  822. self.norm = nn.Identity()
  823. if prune_head:
  824. self.reset_classifier(0, '')
  825. return take_indices
  826. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  827. """Forward pass through feature extraction layers."""
  828. x = self.patch_embed(x)
  829. x = self.layers(x)
  830. x = self.norm(x)
  831. return x
  832. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  833. """Forward pass through classifier head.
  834. Args:
  835. x: Feature tensor.
  836. pre_logits: Return features before final classifier.
  837. Returns:
  838. Output tensor.
  839. """
  840. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  841. def forward(self, x: torch.Tensor) -> torch.Tensor:
  842. """Forward pass.
  843. Args:
  844. x: Input tensor.
  845. Returns:
  846. Output logits.
  847. """
  848. x = self.forward_features(x)
  849. x = self.forward_head(x)
  850. return x
  851. def checkpoint_filter_fn(state_dict: dict, model: nn.Module) -> Dict[str, torch.Tensor]:
  852. """Convert patch embedding weight from manual patchify + linear proj to conv.
  853. Args:
  854. state_dict: State dictionary from checkpoint.
  855. model: Model instance.
  856. Returns:
  857. Filtered state dictionary.
  858. """
  859. old_weights = True
  860. if 'head.fc.weight' in state_dict:
  861. old_weights = False
  862. import re
  863. out_dict = {}
  864. state_dict = state_dict.get('model', state_dict)
  865. state_dict = state_dict.get('state_dict', state_dict)
  866. for k, v in state_dict.items():
  867. if any([n in k for n in ('relative_position_index', 'attn_mask')]):
  868. continue # skip buffers that should not be persistent
  869. if 'patch_embed.proj.weight' in k:
  870. _, _, H, W = model.patch_embed.proj.weight.shape
  871. if v.shape[-2] != H or v.shape[-1] != W:
  872. v = resample_patch_embed(
  873. v,
  874. (H, W),
  875. interpolation='bicubic',
  876. antialias=True,
  877. verbose=True,
  878. )
  879. if k.endswith('relative_position_bias_table'):
  880. m = model.get_submodule(k[:-29])
  881. if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
  882. v = resize_rel_pos_bias_table(
  883. v,
  884. new_window_size=m.window_size,
  885. new_bias_shape=m.relative_position_bias_table.shape,
  886. )
  887. if old_weights:
  888. k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
  889. k = k.replace('head.', 'head.fc.')
  890. out_dict[k] = v
  891. return out_dict
  892. def _create_swin_transformer(variant: str, pretrained: bool = False, **kwargs) -> SwinTransformer:
  893. """Create a Swin Transformer model.
  894. Args:
  895. variant: Model variant name.
  896. pretrained: Load pretrained weights.
  897. **kwargs: Additional model arguments.
  898. Returns:
  899. SwinTransformer model instance.
  900. """
  901. default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
  902. out_indices = kwargs.pop('out_indices', default_out_indices)
  903. model = build_model_with_cfg(
  904. SwinTransformer, variant, pretrained,
  905. pretrained_filter_fn=checkpoint_filter_fn,
  906. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  907. **kwargs)
  908. return model
  909. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  910. """Create default configuration for Swin Transformer models."""
  911. return {
  912. 'url': url,
  913. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  914. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  915. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  916. 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
  917. 'license': 'mit', **kwargs
  918. }
# Pretrained-weight configs keyed '<model name>.<pretrained tag>'. Every entry starts from the
# _cfg() defaults; the 384 entries override input_size/pool_size/crop_pct, and the 'ms_in22k'
# entries (ImageNet-22k heads) override num_classes=21841.
default_cfgs = generate_default_cfgs({
    'swin_small_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth', ),
    'swin_base_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',),
    'swin_base_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'swin_large_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',),
    'swin_large_patch4_window12_384.ms_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'swin_tiny_patch4_window7_224.ms_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',),
    'swin_small_patch4_window7_224.ms_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',),
    'swin_base_patch4_window7_224.ms_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',),
    'swin_base_patch4_window12_384.ms_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),

    # tiny 22k pretrain is worse than 1k, so moved after (untagged priority is based on order)
    'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth',),

    # ImageNet-22k classifier heads (21841 classes)
    'swin_tiny_patch4_window7_224.ms_in22k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth',
        num_classes=21841),
    'swin_small_patch4_window7_224.ms_in22k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth',
        num_classes=21841),
    'swin_base_patch4_window7_224.ms_in22k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
        num_classes=21841),
    'swin_base_patch4_window12_384.ms_in22k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),
    'swin_large_patch4_window7_224.ms_in22k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth',
        num_classes=21841),
    'swin_large_patch4_window12_384.ms_in22k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841),

    # Swin-S3 variants (weights hosted on the pytorch-image-models releases page)
    'swin_s3_tiny_224.ms_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'),
    'swin_s3_small_224.ms_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'),
    'swin_s3_base_224.ms_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'),
})
  988. @register_model
  989. def swin_tiny_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer:
  990. """ Swin-T @ 224x224, trained ImageNet-1k
  991. """
  992. model_args = dict(patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
  993. return _create_swin_transformer(
  994. 'swin_tiny_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs))
  995. @register_model
  996. def swin_small_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer:
  997. """ Swin-S @ 224x224
  998. """
  999. model_args = dict(patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24))
  1000. return _create_swin_transformer(
  1001. 'swin_small_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1002. @register_model
  1003. def swin_base_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer:
  1004. """ Swin-B @ 224x224
  1005. """
  1006. model_args = dict(patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
  1007. return _create_swin_transformer(
  1008. 'swin_base_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1009. @register_model
  1010. def swin_base_patch4_window12_384(pretrained=False, **kwargs) -> SwinTransformer:
  1011. """ Swin-B @ 384x384
  1012. """
  1013. model_args = dict(patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32))
  1014. return _create_swin_transformer(
  1015. 'swin_base_patch4_window12_384', pretrained=pretrained, **dict(model_args, **kwargs))
  1016. @register_model
  1017. def swin_large_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer:
  1018. """ Swin-L @ 224x224
  1019. """
  1020. model_args = dict(patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48))
  1021. return _create_swin_transformer(
  1022. 'swin_large_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1023. @register_model
  1024. def swin_large_patch4_window12_384(pretrained=False, **kwargs) -> SwinTransformer:
  1025. """ Swin-L @ 384x384
  1026. """
  1027. model_args = dict(patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48))
  1028. return _create_swin_transformer(
  1029. 'swin_large_patch4_window12_384', pretrained=pretrained, **dict(model_args, **kwargs))
  1030. @register_model
  1031. def swin_s3_tiny_224(pretrained=False, **kwargs) -> SwinTransformer:
  1032. """ Swin-S3-T @ 224x224, https://arxiv.org/abs/2111.14725
  1033. """
  1034. model_args = dict(
  1035. patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24))
  1036. return _create_swin_transformer('swin_s3_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1037. @register_model
  1038. def swin_s3_small_224(pretrained=False, **kwargs) -> SwinTransformer:
  1039. """ Swin-S3-S @ 224x224, https://arxiv.org/abs/2111.14725
  1040. """
  1041. model_args = dict(
  1042. patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24))
  1043. return _create_swin_transformer('swin_s3_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1044. @register_model
  1045. def swin_s3_base_224(pretrained=False, **kwargs) -> SwinTransformer:
  1046. """ Swin-S3-B @ 224x224, https://arxiv.org/abs/2111.14725
  1047. """
  1048. model_args = dict(
  1049. patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2), num_heads=(3, 6, 12, 24))
  1050. return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1051. register_model_deprecations(__name__, {
  1052. 'swin_base_patch4_window7_224_in22k': 'swin_base_patch4_window7_224.ms_in22k',
  1053. 'swin_base_patch4_window12_384_in22k': 'swin_base_patch4_window12_384.ms_in22k',
  1054. 'swin_large_patch4_window7_224_in22k': 'swin_large_patch4_window7_224.ms_in22k',
  1055. 'swin_large_patch4_window12_384_in22k': 'swin_large_patch4_window12_384.ms_in22k',
  1056. })