tiny_vit.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853
  1. """ TinyViT
  2. Paper: `TinyViT: Fast Pretraining Distillation for Small Vision Transformers`
  3. - https://arxiv.org/abs/2207.10666
  4. Adapted from official impl at https://github.com/microsoft/Cream/tree/main/TinyViT
  5. """
  6. __all__ = ['TinyVit']
  7. import itertools
  8. from functools import partial
  9. from typing import Dict, List, Optional, Tuple, Union, Type, Any
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
  15. trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn, calculate_drop_path_rates
  16. from ._builder import build_model_with_cfg
  17. from ._features import feature_take_indices
  18. from ._features_fx import register_notrace_module
  19. from ._manipulate import checkpoint, checkpoint_seq
  20. from ._registry import register_model, generate_default_cfgs
  21. class ConvNorm(torch.nn.Sequential):
  22. def __init__(
  23. self,
  24. in_chs: int,
  25. out_chs: int,
  26. ks: int = 1,
  27. stride: int = 1,
  28. pad: int = 0,
  29. dilation: int = 1,
  30. groups: int = 1,
  31. bn_weight_init: float = 1,
  32. device=None,
  33. dtype=None,
  34. ):
  35. dd = {'device': device, 'dtype': dtype}
  36. super().__init__()
  37. self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False, **dd)
  38. self.bn = nn.BatchNorm2d(out_chs, **dd)
  39. torch.nn.init.constant_(self.bn.weight, bn_weight_init)
  40. torch.nn.init.constant_(self.bn.bias, 0)
  41. @torch.no_grad()
  42. def fuse(self):
  43. c, bn = self.conv, self.bn
  44. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  45. w = c.weight * w[:, None, None, None]
  46. b = bn.bias - bn.running_mean * bn.weight / \
  47. (bn.running_var + bn.eps) ** 0.5
  48. m = torch.nn.Conv2d(
  49. w.size(1) * self.conv.groups, w.size(0), w.shape[2:],
  50. stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
  51. m.weight.data.copy_(w)
  52. m.bias.data.copy_(b)
  53. return m
  54. class PatchEmbed(nn.Module):
  55. def __init__(
  56. self,
  57. in_chs: int,
  58. out_chs: int,
  59. act_layer: Type[nn.Module],
  60. device=None,
  61. dtype=None,
  62. ):
  63. dd = {'device': device, 'dtype': dtype}
  64. super().__init__()
  65. self.stride = 4
  66. self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1, **dd)
  67. self.act = act_layer()
  68. self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1, **dd)
  69. def forward(self, x):
  70. x = self.conv1(x)
  71. x = self.act(x)
  72. x = self.conv2(x)
  73. return x
  74. class MBConv(nn.Module):
  75. def __init__(
  76. self,
  77. in_chs: int,
  78. out_chs: int,
  79. expand_ratio: float,
  80. act_layer: Type[nn.Module],
  81. drop_path: float,
  82. device=None,
  83. dtype=None,
  84. ):
  85. dd = {'device': device, 'dtype': dtype}
  86. super().__init__()
  87. mid_chs = int(in_chs * expand_ratio)
  88. self.conv1 = ConvNorm(in_chs, mid_chs, ks=1, **dd)
  89. self.act1 = act_layer()
  90. self.conv2 = ConvNorm(mid_chs, mid_chs, ks=3, stride=1, pad=1, groups=mid_chs, **dd)
  91. self.act2 = act_layer()
  92. self.conv3 = ConvNorm(mid_chs, out_chs, ks=1, bn_weight_init=0.0, **dd)
  93. self.act3 = act_layer()
  94. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  95. def forward(self, x):
  96. shortcut = x
  97. x = self.conv1(x)
  98. x = self.act1(x)
  99. x = self.conv2(x)
  100. x = self.act2(x)
  101. x = self.conv3(x)
  102. x = self.drop_path(x)
  103. x += shortcut
  104. x = self.act3(x)
  105. return x
  106. class PatchMerging(nn.Module):
  107. def __init__(
  108. self,
  109. dim: int,
  110. out_dim: int,
  111. act_layer: Type[nn.Module],
  112. device=None,
  113. dtype=None,
  114. ):
  115. dd = {'device': device, 'dtype': dtype}
  116. super().__init__()
  117. self.conv1 = ConvNorm(dim, out_dim, 1, 1, 0, **dd)
  118. self.act1 = act_layer()
  119. self.conv2 = ConvNorm(out_dim, out_dim, 3, 2, 1, groups=out_dim, **dd)
  120. self.act2 = act_layer()
  121. self.conv3 = ConvNorm(out_dim, out_dim, 1, 1, 0, **dd)
  122. def forward(self, x):
  123. x = self.conv1(x)
  124. x = self.act1(x)
  125. x = self.conv2(x)
  126. x = self.act2(x)
  127. x = self.conv3(x)
  128. return x
  129. class ConvLayer(nn.Module):
  130. def __init__(
  131. self,
  132. dim: int,
  133. depth: int,
  134. act_layer: Type[nn.Module],
  135. drop_path: Union[float, List[float]] = 0.,
  136. conv_expand_ratio: float = 4.,
  137. device=None,
  138. dtype=None,
  139. ):
  140. dd = {'device': device, 'dtype': dtype}
  141. super().__init__()
  142. self.dim = dim
  143. self.depth = depth
  144. self.blocks = nn.Sequential(*[
  145. MBConv(
  146. dim,
  147. dim,
  148. conv_expand_ratio,
  149. act_layer,
  150. drop_path[i] if isinstance(drop_path, list) else drop_path,
  151. **dd,
  152. )
  153. for i in range(depth)
  154. ])
  155. def forward(self, x):
  156. x = self.blocks(x)
  157. return x
  158. class NormMlp(nn.Module):
  159. def __init__(
  160. self,
  161. in_features: int,
  162. hidden_features: Optional[int] = None,
  163. out_features: Optional[int] = None,
  164. norm_layer: Type[nn.Module] = nn.LayerNorm,
  165. act_layer: Type[nn.Module] = nn.GELU,
  166. drop: float = 0.,
  167. device=None,
  168. dtype=None,
  169. ):
  170. dd = {'device': device, 'dtype': dtype}
  171. super().__init__()
  172. out_features = out_features or in_features
  173. hidden_features = hidden_features or in_features
  174. self.norm = norm_layer(in_features, **dd)
  175. self.fc1 = nn.Linear(in_features, hidden_features, **dd)
  176. self.act = act_layer()
  177. self.drop1 = nn.Dropout(drop)
  178. self.fc2 = nn.Linear(hidden_features, out_features, **dd)
  179. self.drop2 = nn.Dropout(drop)
  180. def forward(self, x):
  181. x = self.norm(x)
  182. x = self.fc1(x)
  183. x = self.act(x)
  184. x = self.drop1(x)
  185. x = self.fc2(x)
  186. x = self.drop2(x)
  187. return x
  188. class Attention(torch.nn.Module):
  189. fused_attn: torch.jit.Final[bool]
  190. attention_bias_cache: Dict[str, torch.Tensor]
  191. def __init__(
  192. self,
  193. dim: int,
  194. key_dim: int,
  195. num_heads: int = 8,
  196. attn_ratio: int = 4,
  197. resolution: Tuple[int, int] = (14, 14),
  198. device=None,
  199. dtype=None,
  200. ):
  201. dd = {'device': device, 'dtype': dtype}
  202. super().__init__()
  203. assert isinstance(resolution, tuple) and len(resolution) == 2
  204. self.num_heads = num_heads
  205. self.scale = key_dim ** -0.5
  206. self.key_dim = key_dim
  207. self.val_dim = int(attn_ratio * key_dim)
  208. self.out_dim = self.val_dim * num_heads
  209. self.attn_ratio = attn_ratio
  210. self.resolution = resolution
  211. self.fused_attn = use_fused_attn()
  212. self.norm = nn.LayerNorm(dim, **dd)
  213. self.qkv = nn.Linear(dim, num_heads * (self.val_dim + 2 * key_dim), **dd)
  214. self.proj = nn.Linear(self.out_dim, dim, **dd)
  215. points = list(itertools.product(range(resolution[0]), range(resolution[1])))
  216. N = len(points)
  217. attention_offsets = {}
  218. idxs = []
  219. for p1 in points:
  220. for p2 in points:
  221. offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
  222. if offset not in attention_offsets:
  223. attention_offsets[offset] = len(attention_offsets)
  224. idxs.append(attention_offsets[offset])
  225. self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets), **dd))
  226. self.register_buffer(
  227. 'attention_bias_idxs',
  228. torch.tensor(idxs, device=device, dtype=torch.long).view(N, N),
  229. persistent=False,
  230. )
  231. self.attention_bias_cache = {}
  232. @torch.no_grad()
  233. def train(self, mode=True):
  234. super().train(mode)
  235. if mode and self.attention_bias_cache:
  236. self.attention_bias_cache = {} # clear ab cache
  237. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  238. if torch.jit.is_tracing() or self.training:
  239. return self.attention_biases[:, self.attention_bias_idxs]
  240. else:
  241. device_key = str(device)
  242. if device_key not in self.attention_bias_cache:
  243. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  244. return self.attention_bias_cache[device_key]
  245. def forward(self, x):
  246. attn_bias = self.get_attention_biases(x.device)
  247. B, N, _ = x.shape
  248. # Normalization
  249. x = self.norm(x)
  250. qkv = self.qkv(x)
  251. # (B, N, num_heads, d)
  252. q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
  253. # (B, num_heads, N, d)
  254. q = q.permute(0, 2, 1, 3)
  255. k = k.permute(0, 2, 1, 3)
  256. v = v.permute(0, 2, 1, 3)
  257. if self.fused_attn:
  258. x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
  259. else:
  260. q = q * self.scale
  261. attn = q @ k.transpose(-2, -1)
  262. attn = attn + attn_bias
  263. attn = attn.softmax(dim=-1)
  264. x = attn @ v
  265. x = x.transpose(1, 2).reshape(B, N, self.out_dim)
  266. x = self.proj(x)
  267. return x
  268. class TinyVitBlock(nn.Module):
  269. """ TinyViT Block.
  270. Args:
  271. dim (int): Number of input channels.
  272. num_heads (int): Number of attention heads.
  273. window_size (int): Window size.
  274. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  275. drop (float, optional): Dropout rate. Default: 0.0
  276. drop_path (float, optional): Stochastic depth rate. Default: 0.0
  277. local_conv_size (int): the kernel size of the convolution between
  278. Attention and MLP. Default: 3
  279. act_layer: the activation function. Default: nn.GELU
  280. """
  281. def __init__(
  282. self,
  283. dim: int,
  284. num_heads: int,
  285. window_size: int = 7,
  286. mlp_ratio: float = 4.,
  287. drop: float = 0.,
  288. drop_path: float = 0.,
  289. local_conv_size: int = 3,
  290. act_layer: Type[nn.Module] = nn.GELU,
  291. device=None,
  292. dtype=None,
  293. ):
  294. dd = {'device': device, 'dtype': dtype}
  295. super().__init__()
  296. self.dim = dim
  297. self.num_heads = num_heads
  298. assert window_size > 0, 'window_size must be greater than 0'
  299. self.window_size = window_size
  300. self.mlp_ratio = mlp_ratio
  301. assert dim % num_heads == 0, 'dim must be divisible by num_heads'
  302. head_dim = dim // num_heads
  303. window_resolution = (window_size, window_size)
  304. self.attn = Attention(
  305. dim,
  306. head_dim,
  307. num_heads,
  308. attn_ratio=1,
  309. resolution=window_resolution,
  310. **dd,
  311. )
  312. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  313. self.mlp = NormMlp(
  314. in_features=dim,
  315. hidden_features=int(dim * mlp_ratio),
  316. act_layer=act_layer,
  317. drop=drop,
  318. **dd,
  319. )
  320. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  321. pad = local_conv_size // 2
  322. self.local_conv = ConvNorm(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim, **dd)
  323. def forward(self, x):
  324. B, H, W, C = x.shape
  325. L = H * W
  326. shortcut = x
  327. if H == self.window_size and W == self.window_size:
  328. x = x.reshape(B, L, C)
  329. x = self.attn(x)
  330. x = x.view(B, H, W, C)
  331. else:
  332. pad_b = (self.window_size - H % self.window_size) % self.window_size
  333. pad_r = (self.window_size - W % self.window_size) % self.window_size
  334. padding = pad_b > 0 or pad_r > 0
  335. if padding:
  336. x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
  337. # window partition
  338. pH, pW = H + pad_b, W + pad_r
  339. nH = pH // self.window_size
  340. nW = pW // self.window_size
  341. x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
  342. B * nH * nW, self.window_size * self.window_size, C
  343. )
  344. x = self.attn(x)
  345. # window reverse
  346. x = x.view(B, nH, nW, self.window_size, self.window_size, C).transpose(2, 3).reshape(B, pH, pW, C)
  347. if padding:
  348. x = x[:, :H, :W].contiguous()
  349. x = shortcut + self.drop_path1(x)
  350. x = x.permute(0, 3, 1, 2)
  351. x = self.local_conv(x)
  352. x = x.reshape(B, C, L).transpose(1, 2)
  353. x = x + self.drop_path2(self.mlp(x))
  354. return x.view(B, H, W, C)
  355. def extra_repr(self) -> str:
  356. return f"dim={self.dim}, num_heads={self.num_heads}, " \
  357. f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
  358. register_notrace_module(TinyVitBlock)
  359. class TinyVitStage(nn.Module):
  360. """ A basic TinyViT layer for one stage.
  361. Args:
  362. dim (int): Number of input channels.
  363. out_dim: the output dimension of the layer
  364. depth (int): Number of blocks.
  365. num_heads (int): Number of attention heads.
  366. window_size (int): Local window size.
  367. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
  368. drop (float, optional): Dropout rate. Default: 0.0
  369. drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
  370. downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
  371. local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3
  372. act_layer: the activation function. Default: nn.GELU
  373. """
  374. def __init__(
  375. self,
  376. dim: int,
  377. out_dim: int,
  378. depth: int,
  379. num_heads: int,
  380. window_size: int,
  381. mlp_ratio: float = 4.,
  382. drop: float = 0.,
  383. drop_path: Union[float, List[float]] = 0.,
  384. downsample: Optional[Type[nn.Module]] = None,
  385. local_conv_size: int = 3,
  386. act_layer: Type[nn.Module] = nn.GELU,
  387. device=None,
  388. dtype=None,
  389. ):
  390. dd = {'device': device, 'dtype': dtype}
  391. super().__init__()
  392. self.depth = depth
  393. self.out_dim = out_dim
  394. # patch merging layer
  395. if downsample is not None:
  396. self.downsample = downsample(
  397. dim=dim,
  398. out_dim=out_dim,
  399. act_layer=act_layer,
  400. **dd,
  401. )
  402. else:
  403. self.downsample = nn.Identity()
  404. assert dim == out_dim
  405. # build blocks
  406. self.blocks = nn.Sequential(*[
  407. TinyVitBlock(
  408. dim=out_dim,
  409. num_heads=num_heads,
  410. window_size=window_size,
  411. mlp_ratio=mlp_ratio,
  412. drop=drop,
  413. drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
  414. local_conv_size=local_conv_size,
  415. act_layer=act_layer,
  416. **dd,
  417. )
  418. for i in range(depth)])
  419. def forward(self, x):
  420. x = self.downsample(x)
  421. x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
  422. x = self.blocks(x)
  423. x = x.permute(0, 3, 1, 2) # BHWC -> BCHW
  424. return x
  425. def extra_repr(self) -> str:
  426. return f"dim={self.out_dim}, depth={self.depth}"
  427. class TinyVit(nn.Module):
  428. def __init__(
  429. self,
  430. in_chans: int = 3,
  431. num_classes: int = 1000,
  432. global_pool: str = 'avg',
  433. embed_dims: Tuple[int, ...] = (96, 192, 384, 768),
  434. depths: Tuple[int, ...] = (2, 2, 6, 2),
  435. num_heads: Tuple[int, ...] = (3, 6, 12, 24),
  436. window_sizes: Tuple[int, ...] = (7, 7, 14, 7),
  437. mlp_ratio: float = 4.,
  438. drop_rate: float = 0.,
  439. drop_path_rate: float = 0.1,
  440. use_checkpoint: bool = False,
  441. mbconv_expand_ratio: float = 4.0,
  442. local_conv_size: int = 3,
  443. act_layer: Type[nn.Module] = nn.GELU,
  444. device=None,
  445. dtype=None,
  446. ):
  447. super().__init__()
  448. dd = {'device': device, 'dtype': dtype}
  449. self.num_classes = num_classes
  450. self.depths = depths
  451. self.num_stages = len(depths)
  452. self.mlp_ratio = mlp_ratio
  453. self.grad_checkpointing = use_checkpoint
  454. self.patch_embed = PatchEmbed(
  455. in_chs=in_chans,
  456. out_chs=embed_dims[0],
  457. act_layer=act_layer,
  458. **dd,
  459. )
  460. # stochastic depth rate rule
  461. dpr = calculate_drop_path_rates(drop_path_rate, sum(depths))
  462. # build stages
  463. self.stages = nn.Sequential()
  464. stride = self.patch_embed.stride
  465. prev_dim = embed_dims[0]
  466. self.feature_info = []
  467. for stage_idx in range(self.num_stages):
  468. if stage_idx == 0:
  469. stage = ConvLayer(
  470. dim=prev_dim,
  471. depth=depths[stage_idx],
  472. act_layer=act_layer,
  473. drop_path=dpr[:depths[stage_idx]],
  474. conv_expand_ratio=mbconv_expand_ratio,
  475. **dd,
  476. )
  477. else:
  478. out_dim = embed_dims[stage_idx]
  479. drop_path_rate = dpr[sum(depths[:stage_idx]):sum(depths[:stage_idx + 1])]
  480. stage = TinyVitStage(
  481. dim=embed_dims[stage_idx - 1],
  482. out_dim=out_dim,
  483. depth=depths[stage_idx],
  484. num_heads=num_heads[stage_idx],
  485. window_size=window_sizes[stage_idx],
  486. mlp_ratio=self.mlp_ratio,
  487. drop=drop_rate,
  488. local_conv_size=local_conv_size,
  489. drop_path=drop_path_rate,
  490. downsample=PatchMerging,
  491. act_layer=act_layer,
  492. **dd,
  493. )
  494. prev_dim = out_dim
  495. stride *= 2
  496. self.stages.append(stage)
  497. self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{stage_idx}')]
  498. # Classifier head
  499. self.num_features = self.head_hidden_size = embed_dims[-1]
  500. norm_layer_cf = partial(LayerNorm2d, eps=1e-5)
  501. self.head = NormMlpClassifierHead(
  502. self.num_features,
  503. num_classes,
  504. pool_type=global_pool,
  505. norm_layer=norm_layer_cf,
  506. **dd,
  507. )
  508. # init weights
  509. self.apply(self._init_weights)
  510. def _init_weights(self, m):
  511. if isinstance(m, nn.Linear):
  512. trunc_normal_(m.weight, std=.02)
  513. if isinstance(m, nn.Linear) and m.bias is not None:
  514. nn.init.constant_(m.bias, 0)
  515. @torch.jit.ignore
  516. def no_weight_decay_keywords(self):
  517. return {'attention_biases'}
  518. @torch.jit.ignore
  519. def no_weight_decay(self):
  520. return {x for x in self.state_dict().keys() if 'attention_biases' in x}
  521. @torch.jit.ignore
  522. def group_matcher(self, coarse=False):
  523. matcher = dict(
  524. stem=r'^patch_embed',
  525. blocks=r'^stages\.(\d+)' if coarse else [
  526. (r'^stages\.(\d+).downsample', (0,)),
  527. (r'^stages\.(\d+)\.\w+\.(\d+)', None),
  528. ]
  529. )
  530. return matcher
  531. @torch.jit.ignore
  532. def set_grad_checkpointing(self, enable=True):
  533. self.grad_checkpointing = enable
  534. @torch.jit.ignore
  535. def get_classifier(self) -> nn.Module:
  536. return self.head.fc
  537. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  538. self.num_classes = num_classes
  539. self.head.reset(num_classes, pool_type=global_pool)
  540. def forward_intermediates(
  541. self,
  542. x: torch.Tensor,
  543. indices: Optional[Union[int, List[int]]] = None,
  544. norm: bool = False,
  545. stop_early: bool = False,
  546. output_fmt: str = 'NCHW',
  547. intermediates_only: bool = False,
  548. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  549. """ Forward features that returns intermediates.
  550. Args:
  551. x: Input image tensor
  552. indices: Take last n blocks if int, all if None, select matching indices if sequence
  553. norm: Apply norm layer to compatible intermediates
  554. stop_early: Stop iterating over blocks when last desired intermediate hit
  555. output_fmt: Shape of intermediate feature outputs
  556. intermediates_only: Only return intermediate features
  557. Returns:
  558. """
  559. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  560. intermediates = []
  561. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  562. # forward pass
  563. x = self.patch_embed(x)
  564. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  565. stages = self.stages
  566. else:
  567. stages = self.stages[:max_index + 1]
  568. for feat_idx, stage in enumerate(stages):
  569. if self.grad_checkpointing and not torch.jit.is_scripting():
  570. x = checkpoint(stage, x)
  571. else:
  572. x = stage(x)
  573. if feat_idx in take_indices:
  574. intermediates.append(x)
  575. if intermediates_only:
  576. return intermediates
  577. return x, intermediates
  578. def prune_intermediate_layers(
  579. self,
  580. indices: Union[int, List[int]] = 1,
  581. prune_norm: bool = False,
  582. prune_head: bool = True,
  583. ):
  584. """ Prune layers not required for specified intermediates.
  585. """
  586. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  587. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  588. if prune_head:
  589. self.reset_classifier(0, '')
  590. return take_indices
  591. def forward_features(self, x):
  592. x = self.patch_embed(x)
  593. if self.grad_checkpointing and not torch.jit.is_scripting():
  594. x = checkpoint_seq(self.stages, x)
  595. else:
  596. x = self.stages(x)
  597. return x
  598. def forward_head(self, x, pre_logits: bool = False):
  599. x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  600. return x
  601. def forward(self, x):
  602. x = self.forward_features(x)
  603. x = self.forward_head(x)
  604. return x
  605. def checkpoint_filter_fn(state_dict, model):
  606. if 'model' in state_dict.keys():
  607. state_dict = state_dict['model']
  608. target_sd = model.state_dict()
  609. out_dict = {}
  610. for k, v in state_dict.items():
  611. if k.endswith('attention_bias_idxs'):
  612. continue
  613. if 'attention_biases' in k:
  614. # TODO: whether move this func into model for dynamic input resolution? (high risk)
  615. v = resize_rel_pos_bias_table_levit(v.T, target_sd[k].shape[::-1]).T
  616. out_dict[k] = v
  617. return out_dict
  618. def _cfg(url='', **kwargs):
  619. return {
  620. 'url': url,
  621. 'num_classes': 1000,
  622. 'mean': IMAGENET_DEFAULT_MEAN,
  623. 'std': IMAGENET_DEFAULT_STD,
  624. 'first_conv': 'patch_embed.conv1.conv',
  625. 'classifier': 'head.fc',
  626. 'pool_size': (7, 7),
  627. 'input_size': (3, 224, 224),
  628. 'crop_pct': 0.95,
  629. 'license': 'apache-2.0',
  630. **kwargs,
  631. }
  632. default_cfgs = generate_default_cfgs({
  633. 'tiny_vit_5m_224.dist_in22k': _cfg(
  634. hf_hub_id='timm/',
  635. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22k_distill.pth',
  636. num_classes=21841
  637. ),
  638. 'tiny_vit_5m_224.dist_in22k_ft_in1k': _cfg(
  639. hf_hub_id='timm/',
  640. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_22kto1k_distill.pth'
  641. ),
  642. 'tiny_vit_5m_224.in1k': _cfg(
  643. hf_hub_id='timm/',
  644. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_5m_1k.pth'
  645. ),
  646. 'tiny_vit_11m_224.dist_in22k': _cfg(
  647. hf_hub_id='timm/',
  648. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22k_distill.pth',
  649. num_classes=21841
  650. ),
  651. 'tiny_vit_11m_224.dist_in22k_ft_in1k': _cfg(
  652. hf_hub_id='timm/',
  653. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_22kto1k_distill.pth'
  654. ),
  655. 'tiny_vit_11m_224.in1k': _cfg(
  656. hf_hub_id='timm/',
  657. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_11m_1k.pth'
  658. ),
  659. 'tiny_vit_21m_224.dist_in22k': _cfg(
  660. hf_hub_id='timm/',
  661. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22k_distill.pth',
  662. num_classes=21841
  663. ),
  664. 'tiny_vit_21m_224.dist_in22k_ft_in1k': _cfg(
  665. hf_hub_id='timm/',
  666. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_distill.pth'
  667. ),
  668. 'tiny_vit_21m_224.in1k': _cfg(
  669. hf_hub_id='timm/',
  670. #url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_1k.pth'
  671. ),
  672. 'tiny_vit_21m_384.dist_in22k_ft_in1k': _cfg(
  673. hf_hub_id='timm/',
  674. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_384_distill.pth',
  675. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  676. ),
  677. 'tiny_vit_21m_512.dist_in22k_ft_in1k': _cfg(
  678. hf_hub_id='timm/',
  679. # url='https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/tiny_vit_21m_22kto1k_512_distill.pth',
  680. input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash',
  681. ),
  682. })
  683. def _create_tiny_vit(variant, pretrained=False, **kwargs):
  684. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  685. model = build_model_with_cfg(
  686. TinyVit,
  687. variant,
  688. pretrained,
  689. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  690. pretrained_filter_fn=checkpoint_filter_fn,
  691. **kwargs
  692. )
  693. return model
  694. @register_model
  695. def tiny_vit_5m_224(pretrained=False, **kwargs):
  696. model_kwargs = dict(
  697. embed_dims=[64, 128, 160, 320],
  698. depths=[2, 2, 6, 2],
  699. num_heads=[2, 4, 5, 10],
  700. window_sizes=[7, 7, 14, 7],
  701. drop_path_rate=0.0,
  702. )
  703. model_kwargs.update(kwargs)
  704. return _create_tiny_vit('tiny_vit_5m_224', pretrained, **model_kwargs)
  705. @register_model
  706. def tiny_vit_11m_224(pretrained=False, **kwargs):
  707. model_kwargs = dict(
  708. embed_dims=[64, 128, 256, 448],
  709. depths=[2, 2, 6, 2],
  710. num_heads=[2, 4, 8, 14],
  711. window_sizes=[7, 7, 14, 7],
  712. drop_path_rate=0.1,
  713. )
  714. model_kwargs.update(kwargs)
  715. return _create_tiny_vit('tiny_vit_11m_224', pretrained, **model_kwargs)
  716. @register_model
  717. def tiny_vit_21m_224(pretrained=False, **kwargs):
  718. model_kwargs = dict(
  719. embed_dims=[96, 192, 384, 576],
  720. depths=[2, 2, 6, 2],
  721. num_heads=[3, 6, 12, 18],
  722. window_sizes=[7, 7, 14, 7],
  723. drop_path_rate=0.2,
  724. )
  725. model_kwargs.update(kwargs)
  726. return _create_tiny_vit('tiny_vit_21m_224', pretrained, **model_kwargs)
  727. @register_model
  728. def tiny_vit_21m_384(pretrained=False, **kwargs):
  729. model_kwargs = dict(
  730. embed_dims=[96, 192, 384, 576],
  731. depths=[2, 2, 6, 2],
  732. num_heads=[3, 6, 12, 18],
  733. window_sizes=[12, 12, 24, 12],
  734. drop_path_rate=0.1,
  735. )
  736. model_kwargs.update(kwargs)
  737. return _create_tiny_vit('tiny_vit_21m_384', pretrained, **model_kwargs)
  738. @register_model
  739. def tiny_vit_21m_512(pretrained=False, **kwargs):
  740. model_kwargs = dict(
  741. embed_dims=[96, 192, 384, 576],
  742. depths=[2, 2, 6, 2],
  743. num_heads=[3, 6, 12, 18],
  744. window_sizes=[16, 16, 32, 16],
  745. drop_path_rate=0.1,
  746. )
  747. model_kwargs.update(kwargs)
  748. return _create_tiny_vit('tiny_vit_21m_512', pretrained, **model_kwargs)