efficientvit_msra.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795
  1. """ EfficientViT (by MSRA)
  2. Paper: `EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention`
  3. - https://arxiv.org/abs/2305.07027
  4. Adapted from official impl at https://github.com/microsoft/Cream/tree/main/EfficientViT
  5. """
# Public export list: only the model class. Variant constructors (efficientvit_m0..m5) are
# registered with @register_model rather than listed here.
__all__ = ['EfficientVitMsra']
  7. import itertools
  8. from collections import OrderedDict
  9. from typing import Dict, List, Optional, Tuple, Type, Union
  10. import torch
  11. import torch.nn as nn
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from timm.layers import SqueezeExcite, SelectAdaptivePool2d, trunc_normal_, _assert
  14. from ._builder import build_model_with_cfg
  15. from ._features import feature_take_indices
  16. from ._manipulate import checkpoint, checkpoint_seq
  17. from ._registry import register_model, generate_default_cfgs
  18. class ConvNorm(torch.nn.Sequential):
  19. def __init__(
  20. self,
  21. in_chs: int,
  22. out_chs: int,
  23. ks: int = 1,
  24. stride: int = 1,
  25. pad: int = 0,
  26. dilation: int = 1,
  27. groups: int = 1,
  28. bn_weight_init: float = 1,
  29. device=None,
  30. dtype=None,
  31. ):
  32. dd = {'device': device, 'dtype': dtype}
  33. super().__init__()
  34. self.conv = nn.Conv2d(in_chs, out_chs, ks, stride, pad, dilation, groups, bias=False, **dd)
  35. self.bn = nn.BatchNorm2d(out_chs, **dd)
  36. torch.nn.init.constant_(self.bn.weight, bn_weight_init)
  37. @torch.no_grad()
  38. def fuse(self):
  39. c, bn = self.conv, self.bn
  40. w = bn.weight / (bn.running_var + bn.eps)**0.5
  41. w = c.weight * w[:, None, None, None]
  42. b = bn.bias - bn.running_mean * bn.weight / \
  43. (bn.running_var + bn.eps)**0.5
  44. m = torch.nn.Conv2d(
  45. w.size(1) * self.conv.groups, w.size(0), w.shape[2:],
  46. stride=self.conv.stride, padding=self.conv.padding, dilation=self.conv.dilation, groups=self.conv.groups)
  47. m.weight.data.copy_(w)
  48. m.bias.data.copy_(b)
  49. return m
  50. class NormLinear(torch.nn.Sequential):
  51. def __init__(
  52. self,
  53. in_features: int,
  54. out_features: int,
  55. bias: bool = True,
  56. std: float = 0.02,
  57. drop: float = 0.,
  58. device=None,
  59. dtype=None,
  60. ):
  61. dd = {'device': device, 'dtype': dtype}
  62. super().__init__()
  63. self.bn = nn.BatchNorm1d(in_features, **dd)
  64. self.drop = nn.Dropout(drop)
  65. self.linear = nn.Linear(in_features, out_features, bias=bias, **dd)
  66. trunc_normal_(self.linear.weight, std=std)
  67. if self.linear.bias is not None:
  68. nn.init.constant_(self.linear.bias, 0)
  69. @torch.no_grad()
  70. def fuse(self):
  71. bn, linear = self.bn, self.linear
  72. w = bn.weight / (bn.running_var + bn.eps)**0.5
  73. b = bn.bias - self.bn.running_mean * \
  74. self.bn.weight / (bn.running_var + bn.eps)**0.5
  75. w = linear.weight * w[None, :]
  76. if linear.bias is None:
  77. b = b @ self.linear.weight.T
  78. else:
  79. b = (linear.weight @ b[:, None]).view(-1) + self.linear.bias
  80. m = torch.nn.Linear(w.size(1), w.size(0))
  81. m.weight.data.copy_(w)
  82. m.bias.data.copy_(b)
  83. return m
  84. class PatchMerging(torch.nn.Module):
  85. def __init__(
  86. self,
  87. dim: int,
  88. out_dim: int,
  89. device=None,
  90. dtype=None,
  91. ):
  92. dd = {'device': device, 'dtype': dtype}
  93. super().__init__()
  94. hid_dim = int(dim * 4)
  95. self.conv1 = ConvNorm(dim, hid_dim, 1, 1, 0, **dd)
  96. self.act = torch.nn.ReLU()
  97. self.conv2 = ConvNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, **dd)
  98. self.se = SqueezeExcite(hid_dim, .25, **dd)
  99. self.conv3 = ConvNorm(hid_dim, out_dim, 1, 1, 0, **dd)
  100. def forward(self, x):
  101. x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
  102. return x
  103. class ResidualDrop(torch.nn.Module):
  104. def __init__(self, m: nn.Module, drop: float = 0.):
  105. super().__init__()
  106. self.m = m
  107. self.drop = drop
  108. def forward(self, x):
  109. if self.training and self.drop > 0:
  110. return x + self.m(x) * torch.rand(
  111. x.size(0), 1, 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
  112. else:
  113. return x + self.m(x)
  114. class ConvMlp(torch.nn.Module):
  115. def __init__(
  116. self,
  117. ed: int,
  118. h: int,
  119. device=None,
  120. dtype=None,
  121. ):
  122. dd = {'device': device, 'dtype': dtype}
  123. super().__init__()
  124. self.pw1 = ConvNorm(ed, h, **dd)
  125. self.act = torch.nn.ReLU()
  126. self.pw2 = ConvNorm(h, ed, bn_weight_init=0, **dd)
  127. def forward(self, x):
  128. x = self.pw2(self.act(self.pw1(x)))
  129. return x
  130. class CascadedGroupAttention(torch.nn.Module):
  131. attention_bias_cache: Dict[str, torch.Tensor]
  132. r""" Cascaded Group Attention.
  133. Args:
  134. dim (int): Number of input channels.
  135. key_dim (int): The dimension for query and key.
  136. num_heads (int): Number of attention heads.
  137. attn_ratio (int): Multiplier for the query dim for value dimension.
  138. resolution (int): Input resolution, correspond to the window size.
  139. kernels (List[int]): The kernel size of the dw conv on query.
  140. """
  141. def __init__(
  142. self,
  143. dim: int,
  144. key_dim: int,
  145. num_heads: int = 8,
  146. attn_ratio: int = 4,
  147. resolution: int = 14,
  148. kernels: Tuple[int, ...] = (5, 5, 5, 5),
  149. device=None,
  150. dtype=None,
  151. ):
  152. dd = {'device': device, 'dtype': dtype}
  153. super().__init__()
  154. self.num_heads = num_heads
  155. self.scale = key_dim ** -0.5
  156. self.key_dim = key_dim
  157. self.val_dim = int(attn_ratio * key_dim)
  158. self.attn_ratio = attn_ratio
  159. qkvs = []
  160. dws = []
  161. for i in range(num_heads):
  162. qkvs.append(ConvNorm(dim // num_heads, self.key_dim * 2 + self.val_dim, **dd))
  163. dws.append(ConvNorm(self.key_dim, self.key_dim, kernels[i], 1, kernels[i] // 2, groups=self.key_dim, **dd))
  164. self.qkvs = torch.nn.ModuleList(qkvs)
  165. self.dws = torch.nn.ModuleList(dws)
  166. self.proj = torch.nn.Sequential(
  167. torch.nn.ReLU(),
  168. ConvNorm(self.val_dim * num_heads, dim, bn_weight_init=0, **dd)
  169. )
  170. points = list(itertools.product(range(resolution), range(resolution)))
  171. N = len(points)
  172. attention_offsets = {}
  173. idxs = []
  174. for p1 in points:
  175. for p2 in points:
  176. offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
  177. if offset not in attention_offsets:
  178. attention_offsets[offset] = len(attention_offsets)
  179. idxs.append(attention_offsets[offset])
  180. self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets), **dd))
  181. self.register_buffer(
  182. 'attention_bias_idxs',
  183. torch.tensor(idxs, device=device, dtype=torch.long).view(N, N),
  184. persistent=False,
  185. )
  186. self.attention_bias_cache = {}
  187. @torch.no_grad()
  188. def train(self, mode=True):
  189. super().train(mode)
  190. if mode and self.attention_bias_cache:
  191. self.attention_bias_cache = {} # clear ab cache
  192. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  193. if torch.jit.is_tracing() or self.training:
  194. return self.attention_biases[:, self.attention_bias_idxs]
  195. else:
  196. device_key = str(device)
  197. if device_key not in self.attention_bias_cache:
  198. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  199. return self.attention_bias_cache[device_key]
  200. def forward(self, x):
  201. B, C, H, W = x.shape
  202. feats_in = x.chunk(len(self.qkvs), dim=1)
  203. feats_out = []
  204. feat = feats_in[0]
  205. attn_bias = self.get_attention_biases(x.device)
  206. for head_idx, (qkv, dws) in enumerate(zip(self.qkvs, self.dws)):
  207. if head_idx > 0:
  208. feat = feat + feats_in[head_idx]
  209. feat = qkv(feat)
  210. q, k, v = feat.view(B, -1, H, W).split([self.key_dim, self.key_dim, self.val_dim], dim=1)
  211. q = dws(q)
  212. q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
  213. q = q * self.scale
  214. attn = q.transpose(-2, -1) @ k
  215. attn = attn + attn_bias[head_idx]
  216. attn = attn.softmax(dim=-1)
  217. feat = v @ attn.transpose(-2, -1)
  218. feat = feat.view(B, self.val_dim, H, W)
  219. feats_out.append(feat)
  220. x = self.proj(torch.cat(feats_out, 1))
  221. return x
  222. class LocalWindowAttention(torch.nn.Module):
  223. r""" Local Window Attention.
  224. Args:
  225. dim (int): Number of input channels.
  226. key_dim (int): The dimension for query and key.
  227. num_heads (int): Number of attention heads.
  228. attn_ratio (int): Multiplier for the query dim for value dimension.
  229. resolution (int): Input resolution.
  230. window_resolution (int): Local window resolution.
  231. kernels (List[int]): The kernel size of the dw conv on query.
  232. """
  233. def __init__(
  234. self,
  235. dim: int,
  236. key_dim: int,
  237. num_heads: int = 8,
  238. attn_ratio: int = 4,
  239. resolution: int = 14,
  240. window_resolution: int = 7,
  241. kernels: Tuple[int, ...] = (5, 5, 5, 5),
  242. device=None,
  243. dtype=None,
  244. ):
  245. dd = {'device': device, 'dtype': dtype}
  246. super().__init__()
  247. self.dim = dim
  248. self.num_heads = num_heads
  249. self.resolution = resolution
  250. assert window_resolution > 0, 'window_size must be greater than 0'
  251. self.window_resolution = window_resolution
  252. window_resolution = min(window_resolution, resolution)
  253. self.attn = CascadedGroupAttention(
  254. dim, key_dim, num_heads,
  255. attn_ratio=attn_ratio,
  256. resolution=window_resolution,
  257. kernels=kernels,
  258. **dd,
  259. )
  260. def forward(self, x):
  261. H = W = self.resolution
  262. B, C, H_, W_ = x.shape
  263. # Only check this for classification models
  264. _assert(H == H_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
  265. _assert(W == W_, f'input feature has wrong size, expect {(H, W)}, got {(H_, W_)}')
  266. if H <= self.window_resolution and W <= self.window_resolution:
  267. x = self.attn(x)
  268. else:
  269. x = x.permute(0, 2, 3, 1)
  270. pad_b = (self.window_resolution - H % self.window_resolution) % self.window_resolution
  271. pad_r = (self.window_resolution - W % self.window_resolution) % self.window_resolution
  272. x = torch.nn.functional.pad(x, (0, 0, 0, pad_r, 0, pad_b))
  273. pH, pW = H + pad_b, W + pad_r
  274. nH = pH // self.window_resolution
  275. nW = pW // self.window_resolution
  276. # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
  277. x = x.view(B, nH, self.window_resolution, nW, self.window_resolution, C).transpose(2, 3)
  278. x = x.reshape(B * nH * nW, self.window_resolution, self.window_resolution, C).permute(0, 3, 1, 2)
  279. x = self.attn(x)
  280. # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
  281. x = x.permute(0, 2, 3, 1).view(B, nH, nW, self.window_resolution, self.window_resolution, C)
  282. x = x.transpose(2, 3).reshape(B, pH, pW, C)
  283. x = x[:, :H, :W].contiguous()
  284. x = x.permute(0, 3, 1, 2)
  285. return x
  286. class EfficientVitBlock(torch.nn.Module):
  287. """ A basic EfficientVit building block.
  288. Args:
  289. dim (int): Number of input channels.
  290. key_dim (int): Dimension for query and key in the token mixer.
  291. num_heads (int): Number of attention heads.
  292. attn_ratio (int): Multiplier for the query dim for value dimension.
  293. resolution (int): Input resolution.
  294. window_resolution (int): Local window resolution.
  295. kernels (List[int]): The kernel size of the dw conv on query.
  296. """
  297. def __init__(
  298. self,
  299. dim: int,
  300. key_dim: int,
  301. num_heads: int = 8,
  302. attn_ratio: int = 4,
  303. resolution: int = 14,
  304. window_resolution: int = 7,
  305. kernels: List[int] = [5, 5, 5, 5],
  306. device=None,
  307. dtype=None,
  308. ):
  309. dd = {'device': device, 'dtype': dtype}
  310. super().__init__()
  311. self.dw0 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0., **dd))
  312. self.ffn0 = ResidualDrop(ConvMlp(dim, int(dim * 2), **dd))
  313. self.mixer = ResidualDrop(
  314. LocalWindowAttention(
  315. dim, key_dim, num_heads,
  316. attn_ratio=attn_ratio,
  317. resolution=resolution,
  318. window_resolution=window_resolution,
  319. kernels=kernels,
  320. **dd,
  321. ),
  322. )
  323. self.dw1 = ResidualDrop(ConvNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0., **dd))
  324. self.ffn1 = ResidualDrop(ConvMlp(dim, int(dim * 2), **dd))
  325. def forward(self, x):
  326. return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
  327. class EfficientVitStage(torch.nn.Module):
  328. def __init__(
  329. self,
  330. in_dim: int,
  331. out_dim: int,
  332. key_dim: int,
  333. downsample: Tuple[str, int] = ('', 1),
  334. num_heads: int = 8,
  335. attn_ratio: int = 4,
  336. resolution: int = 14,
  337. window_resolution: int = 7,
  338. kernels: List[int] = [5, 5, 5, 5],
  339. depth: int = 1,
  340. device=None,
  341. dtype=None,
  342. ):
  343. dd = {'device': device, 'dtype': dtype}
  344. super().__init__()
  345. if downsample[0] == 'subsample':
  346. self.resolution = (resolution - 1) // downsample[1] + 1
  347. down_blocks = []
  348. down_blocks.append((
  349. 'res1',
  350. torch.nn.Sequential(
  351. ResidualDrop(ConvNorm(in_dim, in_dim, 3, 1, 1, groups=in_dim, **dd)),
  352. ResidualDrop(ConvMlp(in_dim, int(in_dim * 2), **dd)),
  353. )
  354. ))
  355. down_blocks.append(('patchmerge', PatchMerging(in_dim, out_dim, **dd)))
  356. down_blocks.append((
  357. 'res2',
  358. torch.nn.Sequential(
  359. ResidualDrop(ConvNorm(out_dim, out_dim, 3, 1, 1, groups=out_dim, **dd)),
  360. ResidualDrop(ConvMlp(out_dim, int(out_dim * 2), **dd)),
  361. )
  362. ))
  363. self.downsample = nn.Sequential(OrderedDict(down_blocks))
  364. else:
  365. assert in_dim == out_dim
  366. self.downsample = nn.Identity()
  367. self.resolution = resolution
  368. blocks = []
  369. for d in range(depth):
  370. blocks.append(EfficientVitBlock(
  371. out_dim,
  372. key_dim,
  373. num_heads,
  374. attn_ratio,
  375. self.resolution,
  376. window_resolution,
  377. kernels,
  378. **dd,
  379. ))
  380. self.blocks = nn.Sequential(*blocks)
  381. def forward(self, x):
  382. x = self.downsample(x)
  383. x = self.blocks(x)
  384. return x
  385. class PatchEmbedding(torch.nn.Sequential):
  386. def __init__(
  387. self,
  388. in_chans: int,
  389. dim: int,
  390. device=None,
  391. dtype=None,
  392. ):
  393. super().__init__()
  394. dd = {'device': device, 'dtype': dtype}
  395. self.add_module('conv1', ConvNorm(in_chans, dim // 8, 3, 2, 1, **dd))
  396. self.add_module('relu1', torch.nn.ReLU())
  397. self.add_module('conv2', ConvNorm(dim // 8, dim // 4, 3, 2, 1, **dd))
  398. self.add_module('relu2', torch.nn.ReLU())
  399. self.add_module('conv3', ConvNorm(dim // 4, dim // 2, 3, 2, 1, **dd))
  400. self.add_module('relu3', torch.nn.ReLU())
  401. self.add_module('conv4', ConvNorm(dim // 2, dim, 3, 2, 1, **dd))
  402. self.patch_size = 16
  403. class EfficientVitMsra(nn.Module):
  404. def __init__(
  405. self,
  406. img_size: int = 224,
  407. in_chans: int = 3,
  408. num_classes: int = 1000,
  409. embed_dim: Tuple[int, ...] = (64, 128, 192),
  410. key_dim: Tuple[int, ...] = (16, 16, 16),
  411. depth: Tuple[int, ...] = (1, 2, 3),
  412. num_heads: Tuple[int, ...] = (4, 4, 4),
  413. window_size: Tuple[int, ...] = (7, 7, 7),
  414. kernels: Tuple[int, ...] = (5, 5, 5, 5),
  415. down_ops: Tuple[Tuple[str, int], ...] = (('', 1), ('subsample', 2), ('subsample', 2)),
  416. global_pool: str = 'avg',
  417. drop_rate: float = 0.,
  418. device=None,
  419. dtype=None,
  420. ):
  421. super().__init__()
  422. dd = {'device': device, 'dtype': dtype}
  423. self.grad_checkpointing = False
  424. self.num_classes = num_classes
  425. self.drop_rate = drop_rate
  426. # Patch embedding
  427. self.patch_embed = PatchEmbedding(in_chans, embed_dim[0], **dd)
  428. stride = self.patch_embed.patch_size
  429. resolution = img_size // self.patch_embed.patch_size
  430. attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
  431. # Build EfficientVit blocks
  432. self.feature_info = []
  433. stages = []
  434. pre_ed = embed_dim[0]
  435. for i, (ed, kd, dpth, nh, ar, wd, do) in enumerate(
  436. zip(embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
  437. stage = EfficientVitStage(
  438. in_dim=pre_ed,
  439. out_dim=ed,
  440. key_dim=kd,
  441. downsample=do,
  442. num_heads=nh,
  443. attn_ratio=ar,
  444. resolution=resolution,
  445. window_resolution=wd,
  446. kernels=kernels,
  447. depth=dpth,
  448. **dd,
  449. )
  450. pre_ed = ed
  451. if do[0] == 'subsample' and i != 0:
  452. stride *= do[1]
  453. resolution = stage.resolution
  454. stages.append(stage)
  455. self.feature_info += [dict(num_chs=ed, reduction=stride, module=f'stages.{i}')]
  456. self.stages = nn.Sequential(*stages)
  457. if global_pool == 'avg':
  458. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
  459. else:
  460. assert num_classes == 0
  461. self.global_pool = nn.Identity()
  462. self.num_features = self.head_hidden_size = embed_dim[-1]
  463. self.head = NormLinear(
  464. self.num_features, num_classes, drop=self.drop_rate, **dd) if num_classes > 0 else torch.nn.Identity()
  465. @torch.jit.ignore
  466. def no_weight_decay(self):
  467. return {x for x in self.state_dict().keys() if 'attention_biases' in x}
  468. @torch.jit.ignore
  469. def group_matcher(self, coarse=False):
  470. matcher = dict(
  471. stem=r'^patch_embed',
  472. blocks=r'^stages\.(\d+)' if coarse else [
  473. (r'^stages\.(\d+).downsample', (0,)),
  474. (r'^stages\.(\d+)\.\w+\.(\d+)', None),
  475. ]
  476. )
  477. return matcher
  478. @torch.jit.ignore
  479. def set_grad_checkpointing(self, enable=True):
  480. self.grad_checkpointing = enable
  481. @torch.jit.ignore
  482. def get_classifier(self) -> nn.Module:
  483. return self.head.linear
  484. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  485. self.num_classes = num_classes
  486. if global_pool is not None:
  487. if global_pool == 'avg':
  488. self.global_pool = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
  489. else:
  490. assert num_classes == 0
  491. self.global_pool = nn.Identity()
  492. self.head = NormLinear(
  493. self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else torch.nn.Identity()
  494. def forward_intermediates(
  495. self,
  496. x: torch.Tensor,
  497. indices: Optional[Union[int, List[int]]] = None,
  498. norm: bool = False,
  499. stop_early: bool = False,
  500. output_fmt: str = 'NCHW',
  501. intermediates_only: bool = False,
  502. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  503. """ Forward features that returns intermediates.
  504. Args:
  505. x: Input image tensor
  506. indices: Take last n blocks if int, all if None, select matching indices if sequence
  507. norm: Apply norm layer to compatible intermediates
  508. stop_early: Stop iterating over blocks when last desired intermediate hit
  509. output_fmt: Shape of intermediate feature outputs
  510. intermediates_only: Only return intermediate features
  511. Returns:
  512. """
  513. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  514. intermediates = []
  515. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  516. # forward pass
  517. x = self.patch_embed(x)
  518. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  519. stages = self.stages
  520. else:
  521. stages = self.stages[:max_index + 1]
  522. for feat_idx, stage in enumerate(stages):
  523. if self.grad_checkpointing and not torch.jit.is_scripting():
  524. x = checkpoint(stage, x)
  525. else:
  526. x = stage(x)
  527. if feat_idx in take_indices:
  528. intermediates.append(x)
  529. if intermediates_only:
  530. return intermediates
  531. return x, intermediates
  532. def prune_intermediate_layers(
  533. self,
  534. indices: Union[int, List[int]] = 1,
  535. prune_norm: bool = False,
  536. prune_head: bool = True,
  537. ):
  538. """ Prune layers not required for specified intermediates.
  539. """
  540. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  541. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  542. if prune_head:
  543. self.reset_classifier(0, '')
  544. return take_indices
  545. def forward_features(self, x):
  546. x = self.patch_embed(x)
  547. if self.grad_checkpointing and not torch.jit.is_scripting():
  548. x = checkpoint_seq(self.stages, x)
  549. else:
  550. x = self.stages(x)
  551. return x
  552. def forward_head(self, x, pre_logits: bool = False):
  553. x = self.global_pool(x)
  554. return x if pre_logits else self.head(x)
  555. def forward(self, x):
  556. x = self.forward_features(x)
  557. x = self.forward_head(x)
  558. return x
  559. # def checkpoint_filter_fn(state_dict, model):
  560. # if 'model' in state_dict.keys():
  561. # state_dict = state_dict['model']
  562. # tmp_dict = {}
  563. # out_dict = {}
  564. # target_keys = model.state_dict().keys()
  565. # target_keys = [k for k in target_keys if k.startswith('stages.')]
  566. #
  567. # for k, v in state_dict.items():
  568. # if 'attention_bias_idxs' in k:
  569. # continue
  570. # k = k.split('.')
  571. # if k[-2] == 'c':
  572. # k[-2] = 'conv'
  573. # if k[-2] == 'l':
  574. # k[-2] = 'linear'
  575. # k = '.'.join(k)
  576. # tmp_dict[k] = v
  577. #
  578. # for k, v in tmp_dict.items():
  579. # if k.startswith('patch_embed'):
  580. # k = k.split('.')
  581. # k[1] = 'conv' + str(int(k[1]) // 2 + 1)
  582. # k = '.'.join(k)
  583. # elif k.startswith('blocks'):
  584. # kw = '.'.join(k.split('.')[2:])
  585. # find_kw = [a for a in list(sorted(tmp_dict.keys())) if kw in a]
  586. # idx = find_kw.index(k)
  587. # k = [a for a in target_keys if kw in a][idx]
  588. # out_dict[k] = v
  589. #
  590. # return out_dict
  591. def _cfg(url='', **kwargs):
  592. return {
  593. 'url': url,
  594. 'num_classes': 1000,
  595. 'mean': IMAGENET_DEFAULT_MEAN,
  596. 'std': IMAGENET_DEFAULT_STD,
  597. 'first_conv': 'patch_embed.conv1.conv',
  598. 'classifier': 'head.linear',
  599. 'fixed_input_size': True,
  600. 'pool_size': (4, 4),
  601. 'license': 'mit',
  602. **kwargs,
  603. }
  604. default_cfgs = generate_default_cfgs({
  605. 'efficientvit_m0.r224_in1k': _cfg(
  606. hf_hub_id='timm/',
  607. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m0.pth'
  608. ),
  609. 'efficientvit_m1.r224_in1k': _cfg(
  610. hf_hub_id='timm/',
  611. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m1.pth'
  612. ),
  613. 'efficientvit_m2.r224_in1k': _cfg(
  614. hf_hub_id='timm/',
  615. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m2.pth'
  616. ),
  617. 'efficientvit_m3.r224_in1k': _cfg(
  618. hf_hub_id='timm/',
  619. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m3.pth'
  620. ),
  621. 'efficientvit_m4.r224_in1k': _cfg(
  622. hf_hub_id='timm/',
  623. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m4.pth'
  624. ),
  625. 'efficientvit_m5.r224_in1k': _cfg(
  626. hf_hub_id='timm/',
  627. #url='https://github.com/xinyuliu-jeffrey/EfficientVit_Model_Zoo/releases/download/v1.0/efficientvit_m5.pth'
  628. ),
  629. })
  630. def _create_efficientvit_msra(variant, pretrained=False, **kwargs):
  631. out_indices = kwargs.pop('out_indices', (0, 1, 2))
  632. model = build_model_with_cfg(
  633. EfficientVitMsra,
  634. variant,
  635. pretrained,
  636. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  637. **kwargs
  638. )
  639. return model
  640. @register_model
  641. def efficientvit_m0(pretrained=False, **kwargs):
  642. model_args = dict(
  643. img_size=224,
  644. embed_dim=[64, 128, 192],
  645. depth=[1, 2, 3],
  646. num_heads=[4, 4, 4],
  647. window_size=[7, 7, 7],
  648. kernels=[5, 5, 5, 5]
  649. )
  650. return _create_efficientvit_msra('efficientvit_m0', pretrained=pretrained, **dict(model_args, **kwargs))
  651. @register_model
  652. def efficientvit_m1(pretrained=False, **kwargs):
  653. model_args = dict(
  654. img_size=224,
  655. embed_dim=[128, 144, 192],
  656. depth=[1, 2, 3],
  657. num_heads=[2, 3, 3],
  658. window_size=[7, 7, 7],
  659. kernels=[7, 5, 3, 3]
  660. )
  661. return _create_efficientvit_msra('efficientvit_m1', pretrained=pretrained, **dict(model_args, **kwargs))
  662. @register_model
  663. def efficientvit_m2(pretrained=False, **kwargs):
  664. model_args = dict(
  665. img_size=224,
  666. embed_dim=[128, 192, 224],
  667. depth=[1, 2, 3],
  668. num_heads=[4, 3, 2],
  669. window_size=[7, 7, 7],
  670. kernels=[7, 5, 3, 3]
  671. )
  672. return _create_efficientvit_msra('efficientvit_m2', pretrained=pretrained, **dict(model_args, **kwargs))
  673. @register_model
  674. def efficientvit_m3(pretrained=False, **kwargs):
  675. model_args = dict(
  676. img_size=224,
  677. embed_dim=[128, 240, 320],
  678. depth=[1, 2, 3],
  679. num_heads=[4, 3, 4],
  680. window_size=[7, 7, 7],
  681. kernels=[5, 5, 5, 5]
  682. )
  683. return _create_efficientvit_msra('efficientvit_m3', pretrained=pretrained, **dict(model_args, **kwargs))
  684. @register_model
  685. def efficientvit_m4(pretrained=False, **kwargs):
  686. model_args = dict(
  687. img_size=224,
  688. embed_dim=[128, 256, 384],
  689. depth=[1, 2, 3],
  690. num_heads=[4, 4, 4],
  691. window_size=[7, 7, 7],
  692. kernels=[7, 5, 3, 3]
  693. )
  694. return _create_efficientvit_msra('efficientvit_m4', pretrained=pretrained, **dict(model_args, **kwargs))
  695. @register_model
  696. def efficientvit_m5(pretrained=False, **kwargs):
  697. model_args = dict(
  698. img_size=224,
  699. embed_dim=[192, 288, 384],
  700. depth=[1, 3, 4],
  701. num_heads=[3, 3, 4],
  702. window_size=[7, 7, 7],
  703. kernels=[7, 5, 3, 3]
  704. )
  705. return _create_efficientvit_msra('efficientvit_m5', pretrained=pretrained, **dict(model_args, **kwargs))