efficientformer_v2.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885
  1. """ EfficientFormer-V2
  2. @article{
  3. li2022rethinking,
  4. title={Rethinking Vision Transformers for MobileNet Size and Speed},
  5. author={Li, Yanyu and Hu, Ju and Wen, Yang and Evangelidis, Georgios and Salahi, Kamyar and Wang, Yanzhi and Tulyakov, Sergey and Ren, Jian},
  6. journal={arXiv preprint arXiv:2212.08059},
  7. year={2022}
  8. }
  9. Significantly refactored and cleaned up for timm from original at: https://github.com/snap-research/EfficientFormer
  10. Original code licensed Apache 2.0, Copyright (c) 2022 Snap Inc.
  11. Modifications and timm support by / Copyright 2023, Ross Wightman
  12. """
  13. import math
  14. from functools import partial
  15. from typing import Dict, List, Optional, Tuple, Type, Union
  16. import torch
  17. import torch.nn as nn
  18. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  19. from timm.layers import (
  20. create_conv2d,
  21. create_norm_layer,
  22. get_act_layer,
  23. get_norm_layer,
  24. ConvNormAct,
  25. LayerScale2d,
  26. DropPath,
  27. calculate_drop_path_rates,
  28. trunc_normal_,
  29. to_2tuple,
  30. to_ntuple,
  31. ndgrid,
  32. )
  33. from ._builder import build_model_with_cfg
  34. from ._features import feature_take_indices
  35. from ._manipulate import checkpoint_seq
  36. from ._registry import generate_default_cfgs, register_model
  37. __all__ = ['EfficientFormerV2']
  38. EfficientFormer_width = {
  39. 'L': (40, 80, 192, 384), # 26m 83.3% 6attn
  40. 'S2': (32, 64, 144, 288), # 12m 81.6% 4attn dp0.02
  41. 'S1': (32, 48, 120, 224), # 6.1m 79.0
  42. 'S0': (32, 48, 96, 176), # 75.0 75.7
  43. }
  44. EfficientFormer_depth = {
  45. 'L': (5, 5, 15, 10), # 26m 83.3%
  46. 'S2': (4, 4, 12, 8), # 12m
  47. 'S1': (3, 3, 9, 6), # 79.0
  48. 'S0': (2, 2, 6, 4), # 75.7
  49. }
  50. EfficientFormer_expansion_ratios = {
  51. 'L': (4, 4, (4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4), (4, 4, 4, 3, 3, 3, 3, 4, 4, 4)),
  52. 'S2': (4, 4, (4, 4, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4), (4, 4, 3, 3, 3, 3, 4, 4)),
  53. 'S1': (4, 4, (4, 4, 3, 3, 3, 3, 4, 4, 4), (4, 4, 3, 3, 4, 4)),
  54. 'S0': (4, 4, (4, 3, 3, 3, 4, 4), (4, 3, 3, 4)),
  55. }
  56. class ConvNorm(nn.Module):
  57. def __init__(
  58. self,
  59. in_channels: int,
  60. out_channels: int,
  61. kernel_size: int = 1,
  62. stride: int = 1,
  63. padding: Union[int, str] = '',
  64. dilation: int = 1,
  65. groups: int = 1,
  66. bias: bool = True,
  67. norm_layer: str = 'batchnorm2d',
  68. norm_kwargs: Optional[Dict] = None,
  69. device=None,
  70. dtype=None,
  71. ):
  72. dd = {'device': device, 'dtype': dtype}
  73. norm_kwargs = norm_kwargs or {}
  74. super().__init__()
  75. self.conv = create_conv2d(
  76. in_channels,
  77. out_channels,
  78. kernel_size,
  79. stride=stride,
  80. padding=padding,
  81. dilation=dilation,
  82. groups=groups,
  83. bias=bias,
  84. **dd,
  85. )
  86. self.bn = create_norm_layer(norm_layer, out_channels, **norm_kwargs, **dd)
  87. def forward(self, x):
  88. x = self.conv(x)
  89. x = self.bn(x)
  90. return x
  91. class Attention2d(torch.nn.Module):
  92. attention_bias_cache: Dict[str, torch.Tensor]
  93. def __init__(
  94. self,
  95. dim: int = 384,
  96. key_dim: int = 32,
  97. num_heads: int = 8,
  98. attn_ratio: int = 4,
  99. resolution: Union[int, Tuple[int, int]] = 7,
  100. act_layer: Type[nn.Module] = nn.GELU,
  101. stride: Optional[int] = None,
  102. device=None,
  103. dtype=None,
  104. ):
  105. dd = {'device': device, 'dtype': dtype}
  106. super().__init__()
  107. self.num_heads = num_heads
  108. self.scale = key_dim ** -0.5
  109. self.key_dim = key_dim
  110. resolution = to_2tuple(resolution)
  111. if stride is not None:
  112. resolution = tuple([math.ceil(r / stride) for r in resolution])
  113. self.stride_conv = ConvNorm(dim, dim, kernel_size=3, stride=stride, groups=dim, **dd)
  114. self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
  115. else:
  116. self.stride_conv = None
  117. self.upsample = None
  118. self.resolution = resolution
  119. self.N = self.resolution[0] * self.resolution[1]
  120. self.d = int(attn_ratio * key_dim)
  121. self.dh = int(attn_ratio * key_dim) * num_heads
  122. self.attn_ratio = attn_ratio
  123. kh = self.key_dim * self.num_heads
  124. self.q = ConvNorm(dim, kh, **dd)
  125. self.k = ConvNorm(dim, kh, **dd)
  126. self.v = ConvNorm(dim, self.dh, **dd)
  127. self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, groups=self.dh, **dd)
  128. self.talking_head1 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, **dd)
  129. self.talking_head2 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1, **dd)
  130. self.act = act_layer()
  131. self.proj = ConvNorm(self.dh, dim, 1, **dd)
  132. pos = torch.stack(ndgrid(
  133. torch.arange(self.resolution[0], device=device, dtype=torch.long),
  134. torch.arange(self.resolution[1], device=device, dtype=torch.long),
  135. )).flatten(1)
  136. rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
  137. rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
  138. self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N, **dd))
  139. self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
  140. self.attention_bias_cache = {} # per-device attention_biases cache (data-parallel compat)
  141. @torch.no_grad()
  142. def train(self, mode=True):
  143. super().train(mode)
  144. if mode and self.attention_bias_cache:
  145. self.attention_bias_cache = {} # clear ab cache
  146. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  147. if torch.jit.is_tracing() or self.training:
  148. return self.attention_biases[:, self.attention_bias_idxs]
  149. else:
  150. device_key = str(device)
  151. if device_key not in self.attention_bias_cache:
  152. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  153. return self.attention_bias_cache[device_key]
  154. def forward(self, x):
  155. B, C, H, W = x.shape
  156. if self.stride_conv is not None:
  157. x = self.stride_conv(x)
  158. q = self.q(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2)
  159. k = self.k(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 2, 3)
  160. v = self.v(x)
  161. v_local = self.v_local(v)
  162. v = v.reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2)
  163. attn = (q @ k) * self.scale
  164. attn = attn + self.get_attention_biases(x.device)
  165. attn = self.talking_head1(attn)
  166. attn = attn.softmax(dim=-1)
  167. attn = self.talking_head2(attn)
  168. x = (attn @ v).transpose(2, 3)
  169. x = x.reshape(B, self.dh, self.resolution[0], self.resolution[1]) + v_local
  170. if self.upsample is not None:
  171. x = self.upsample(x)
  172. x = self.act(x)
  173. x = self.proj(x)
  174. return x
  175. class LocalGlobalQuery(torch.nn.Module):
  176. def __init__(
  177. self,
  178. in_dim: int,
  179. out_dim: int,
  180. device=None,
  181. dtype=None,
  182. ):
  183. dd = {'device': device, 'dtype': dtype}
  184. super().__init__()
  185. self.pool = nn.AvgPool2d(1, 2, 0)
  186. self.local = nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=2, padding=1, groups=in_dim, **dd)
  187. self.proj = ConvNorm(in_dim, out_dim, 1, **dd)
  188. def forward(self, x):
  189. local_q = self.local(x)
  190. pool_q = self.pool(x)
  191. q = local_q + pool_q
  192. q = self.proj(q)
  193. return q
  194. class Attention2dDownsample(torch.nn.Module):
  195. attention_bias_cache: Dict[str, torch.Tensor]
  196. def __init__(
  197. self,
  198. dim: int = 384,
  199. key_dim: int = 16,
  200. num_heads: int = 8,
  201. attn_ratio: int = 4,
  202. resolution: Union[int, Tuple[int, int]] = 7,
  203. out_dim: Optional[int] = None,
  204. act_layer: Type[nn.Module] = nn.GELU,
  205. device=None,
  206. dtype=None,
  207. ):
  208. dd = {'device': device, 'dtype': dtype}
  209. super().__init__()
  210. self.num_heads = num_heads
  211. self.scale = key_dim ** -0.5
  212. self.key_dim = key_dim
  213. self.resolution = to_2tuple(resolution)
  214. self.resolution2 = tuple([math.ceil(r / 2) for r in self.resolution])
  215. self.N = self.resolution[0] * self.resolution[1]
  216. self.N2 = self.resolution2[0] * self.resolution2[1]
  217. self.d = int(attn_ratio * key_dim)
  218. self.dh = int(attn_ratio * key_dim) * num_heads
  219. self.attn_ratio = attn_ratio
  220. self.out_dim = out_dim or dim
  221. kh = self.key_dim * self.num_heads
  222. self.q = LocalGlobalQuery(dim, kh, **dd)
  223. self.k = ConvNorm(dim, kh, 1, **dd)
  224. self.v = ConvNorm(dim, self.dh, 1, **dd)
  225. self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, stride=2, groups=self.dh, **dd)
  226. self.act = act_layer()
  227. self.proj = ConvNorm(self.dh, self.out_dim, 1, **dd)
  228. self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N, **dd))
  229. k_pos = torch.stack(ndgrid(
  230. torch.arange(self.resolution[0], device=device, dtype=torch.long),
  231. torch.arange(self.resolution[1], device=device, dtype=torch.long),
  232. )).flatten(1)
  233. q_pos = torch.stack(ndgrid(
  234. torch.arange(0, self.resolution[0], step=2, device=device, dtype=torch.long),
  235. torch.arange(0, self.resolution[1], step=2, device=device, dtype=torch.long),
  236. )).flatten(1)
  237. rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
  238. rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
  239. self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
  240. self.attention_bias_cache = {} # per-device attention_biases cache (data-parallel compat)
  241. @torch.no_grad()
  242. def train(self, mode=True):
  243. super().train(mode)
  244. if mode and self.attention_bias_cache:
  245. self.attention_bias_cache = {} # clear ab cache
  246. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  247. if torch.jit.is_tracing() or self.training:
  248. return self.attention_biases[:, self.attention_bias_idxs]
  249. else:
  250. device_key = str(device)
  251. if device_key not in self.attention_bias_cache:
  252. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  253. return self.attention_bias_cache[device_key]
  254. def forward(self, x):
  255. B, C, H, W = x.shape
  256. q = self.q(x).reshape(B, self.num_heads, -1, self.N2).permute(0, 1, 3, 2)
  257. k = self.k(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 2, 3)
  258. v = self.v(x)
  259. v_local = self.v_local(v)
  260. v = v.reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2)
  261. attn = (q @ k) * self.scale
  262. attn = attn + self.get_attention_biases(x.device)
  263. attn = attn.softmax(dim=-1)
  264. x = (attn @ v).transpose(2, 3)
  265. x = x.reshape(B, self.dh, self.resolution2[0], self.resolution2[1]) + v_local
  266. x = self.act(x)
  267. x = self.proj(x)
  268. return x
  269. class Downsample(nn.Module):
  270. def __init__(
  271. self,
  272. in_chs: int,
  273. out_chs: int,
  274. kernel_size: Union[int, Tuple[int, int]] = 3,
  275. stride: Union[int, Tuple[int, int]] = 2,
  276. padding: Union[int, Tuple[int, int]] = 1,
  277. resolution: Union[int, Tuple[int, int]] = 7,
  278. use_attn: bool = False,
  279. act_layer: Type[nn.Module] = nn.GELU,
  280. norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
  281. device=None,
  282. dtype=None,
  283. ):
  284. dd = {'device': device, 'dtype': dtype}
  285. super().__init__()
  286. kernel_size = to_2tuple(kernel_size)
  287. stride = to_2tuple(stride)
  288. padding = to_2tuple(padding)
  289. norm_layer = norm_layer or nn.Identity()
  290. self.conv = ConvNorm(
  291. in_chs,
  292. out_chs,
  293. kernel_size=kernel_size,
  294. stride=stride,
  295. padding=padding,
  296. norm_layer=norm_layer,
  297. **dd,
  298. )
  299. if use_attn:
  300. self.attn = Attention2dDownsample(
  301. dim=in_chs,
  302. out_dim=out_chs,
  303. resolution=resolution,
  304. act_layer=act_layer,
  305. **dd,
  306. )
  307. else:
  308. self.attn = None
  309. def forward(self, x):
  310. out = self.conv(x)
  311. if self.attn is not None:
  312. return self.attn(x) + out
  313. return out
  314. class ConvMlpWithNorm(nn.Module):
  315. """
  316. Implementation of MLP with 1*1 convolutions.
  317. Input: tensor with shape [B, C, H, W]
  318. """
  319. def __init__(
  320. self,
  321. in_features: int,
  322. hidden_features: Optional[int] = None,
  323. out_features: Optional[int] = None,
  324. act_layer: Type[nn.Module] = nn.GELU,
  325. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  326. drop: float = 0.,
  327. mid_conv: bool = False,
  328. device=None,
  329. dtype=None,
  330. ):
  331. dd = {'device': device, 'dtype': dtype}
  332. super().__init__()
  333. out_features = out_features or in_features
  334. hidden_features = hidden_features or in_features
  335. self.fc1 = ConvNormAct(
  336. in_features,
  337. hidden_features,
  338. 1,
  339. bias=True,
  340. norm_layer=norm_layer,
  341. act_layer=act_layer,
  342. **dd,
  343. )
  344. if mid_conv:
  345. self.mid = ConvNormAct(
  346. hidden_features,
  347. hidden_features,
  348. 3,
  349. groups=hidden_features,
  350. bias=True,
  351. norm_layer=norm_layer,
  352. act_layer=act_layer,
  353. **dd,
  354. )
  355. else:
  356. self.mid = nn.Identity()
  357. self.drop1 = nn.Dropout(drop)
  358. self.fc2 = ConvNorm(hidden_features, out_features, 1, norm_layer=norm_layer, **dd)
  359. self.drop2 = nn.Dropout(drop)
  360. def forward(self, x):
  361. x = self.fc1(x)
  362. x = self.mid(x)
  363. x = self.drop1(x)
  364. x = self.fc2(x)
  365. x = self.drop2(x)
  366. return x
  367. class EfficientFormerV2Block(nn.Module):
  368. def __init__(
  369. self,
  370. dim: int,
  371. mlp_ratio: float = 4.,
  372. act_layer: Type[nn.Module] = nn.GELU,
  373. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  374. proj_drop: float = 0.,
  375. drop_path: float = 0.,
  376. layer_scale_init_value: Optional[float] = 1e-5,
  377. resolution: Union[int, Tuple[int, int]] = 7,
  378. stride: Optional[int] = None,
  379. use_attn: bool = True,
  380. device=None,
  381. dtype=None,
  382. ):
  383. dd = {'device': device, 'dtype': dtype}
  384. super().__init__()
  385. if use_attn:
  386. self.token_mixer = Attention2d(
  387. dim,
  388. resolution=resolution,
  389. act_layer=act_layer,
  390. stride=stride,
  391. **dd,
  392. )
  393. self.ls1 = LayerScale2d(
  394. dim, layer_scale_init_value, **dd) if layer_scale_init_value is not None else nn.Identity()
  395. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  396. else:
  397. self.token_mixer = None
  398. self.ls1 = None
  399. self.drop_path1 = None
  400. self.mlp = ConvMlpWithNorm(
  401. in_features=dim,
  402. hidden_features=int(dim * mlp_ratio),
  403. act_layer=act_layer,
  404. norm_layer=norm_layer,
  405. drop=proj_drop,
  406. mid_conv=True,
  407. **dd,
  408. )
  409. self.ls2 = LayerScale2d(
  410. dim, layer_scale_init_value, **dd) if layer_scale_init_value is not None else nn.Identity()
  411. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  412. def forward(self, x):
  413. if self.token_mixer is not None:
  414. x = x + self.drop_path1(self.ls1(self.token_mixer(x)))
  415. x = x + self.drop_path2(self.ls2(self.mlp(x)))
  416. return x
  417. class Stem4(nn.Sequential):
  418. def __init__(
  419. self,
  420. in_chs: int,
  421. out_chs: int,
  422. act_layer: Type[nn.Module] = nn.GELU,
  423. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  424. device=None,
  425. dtype=None,
  426. ):
  427. dd = {'device': device, 'dtype': dtype}
  428. super().__init__()
  429. self.stride = 4
  430. self.conv1 = ConvNormAct(
  431. in_chs,
  432. out_chs // 2,
  433. kernel_size=3,
  434. stride=2, padding=1,
  435. bias=True,
  436. norm_layer=norm_layer,
  437. act_layer=act_layer,
  438. **dd,
  439. )
  440. self.conv2 = ConvNormAct(
  441. out_chs // 2,
  442. out_chs,
  443. kernel_size=3,
  444. stride=2,
  445. padding=1,
  446. bias=True,
  447. norm_layer=norm_layer,
  448. act_layer=act_layer,
  449. **dd,
  450. )
  451. class EfficientFormerV2Stage(nn.Module):
  452. def __init__(
  453. self,
  454. dim: int,
  455. dim_out: int,
  456. depth: int,
  457. resolution: Union[int, Tuple[int, int]] = 7,
  458. downsample: bool = True,
  459. block_stride: Optional[int] = None,
  460. downsample_use_attn: bool = False,
  461. block_use_attn: bool = False,
  462. num_vit: int = 1,
  463. mlp_ratio: Union[float, Tuple[float, ...]] = 4.,
  464. proj_drop: float = .0,
  465. drop_path: Union[float, List[float]] = 0.,
  466. layer_scale_init_value: Optional[float] = 1e-5,
  467. act_layer: Type[nn.Module] = nn.GELU,
  468. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  469. device=None,
  470. dtype=None,
  471. ):
  472. dd = {'device': device, 'dtype': dtype}
  473. super().__init__()
  474. self.grad_checkpointing = False
  475. mlp_ratio = to_ntuple(depth)(mlp_ratio)
  476. resolution = to_2tuple(resolution)
  477. if downsample:
  478. self.downsample = Downsample(
  479. dim,
  480. dim_out,
  481. use_attn=downsample_use_attn,
  482. resolution=resolution,
  483. norm_layer=norm_layer,
  484. act_layer=act_layer,
  485. **dd,
  486. )
  487. dim = dim_out
  488. resolution = tuple([math.ceil(r / 2) for r in resolution])
  489. else:
  490. assert dim == dim_out
  491. self.downsample = nn.Identity()
  492. blocks = []
  493. for block_idx in range(depth):
  494. remain_idx = depth - num_vit - 1
  495. b = EfficientFormerV2Block(
  496. dim,
  497. resolution=resolution,
  498. stride=block_stride,
  499. mlp_ratio=mlp_ratio[block_idx],
  500. use_attn=block_use_attn and block_idx > remain_idx,
  501. proj_drop=proj_drop,
  502. drop_path=drop_path[block_idx],
  503. layer_scale_init_value=layer_scale_init_value,
  504. act_layer=act_layer,
  505. norm_layer=norm_layer,
  506. **dd,
  507. )
  508. blocks += [b]
  509. self.blocks = nn.Sequential(*blocks)
  510. def forward(self, x):
  511. x = self.downsample(x)
  512. if self.grad_checkpointing and not torch.jit.is_scripting():
  513. x = checkpoint_seq(self.blocks, x)
  514. else:
  515. x = self.blocks(x)
  516. return x
  517. class EfficientFormerV2(nn.Module):
  518. def __init__(
  519. self,
  520. depths: Tuple[int, ...],
  521. in_chans: int = 3,
  522. img_size: Union[int, Tuple[int, int]] = 224,
  523. global_pool: str = 'avg',
  524. embed_dims: Optional[Tuple[int, ...]] = None,
  525. downsamples: Optional[Tuple[bool, ...]] = None,
  526. mlp_ratios: Union[float, Tuple[float, ...], Tuple[Tuple[float, ...], ...]] = 4,
  527. norm_layer: str = 'batchnorm2d',
  528. norm_eps: float = 1e-5,
  529. act_layer: str = 'gelu',
  530. num_classes: int = 1000,
  531. drop_rate: float = 0.,
  532. proj_drop_rate: float = 0.,
  533. drop_path_rate: float = 0.,
  534. layer_scale_init_value: Optional[float] = 1e-5,
  535. num_vit: int = 0,
  536. distillation: bool = True,
  537. device=None,
  538. dtype=None,
  539. ):
  540. super().__init__()
  541. dd = {'device': device, 'dtype': dtype}
  542. assert global_pool in ('avg', '')
  543. self.num_classes = num_classes
  544. self.global_pool = global_pool
  545. self.feature_info = []
  546. img_size = to_2tuple(img_size)
  547. norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
  548. act_layer = get_act_layer(act_layer)
  549. self.stem = Stem4(in_chans, embed_dims[0], act_layer=act_layer, norm_layer=norm_layer, **dd)
  550. prev_dim = embed_dims[0]
  551. stride = 4
  552. num_stages = len(depths)
  553. dpr = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  554. downsamples = downsamples or (False,) + (True,) * (len(depths) - 1)
  555. mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
  556. stages = []
  557. for i in range(num_stages):
  558. curr_resolution = tuple([math.ceil(s / stride) for s in img_size])
  559. stage = EfficientFormerV2Stage(
  560. prev_dim,
  561. embed_dims[i],
  562. depth=depths[i],
  563. resolution=curr_resolution,
  564. downsample=downsamples[i],
  565. block_stride=2 if i == 2 else None,
  566. downsample_use_attn=i >= 3,
  567. block_use_attn=i >= 2,
  568. num_vit=num_vit,
  569. mlp_ratio=mlp_ratios[i],
  570. proj_drop=proj_drop_rate,
  571. drop_path=dpr[i],
  572. layer_scale_init_value=layer_scale_init_value,
  573. act_layer=act_layer,
  574. norm_layer=norm_layer,
  575. **dd,
  576. )
  577. if downsamples[i]:
  578. stride *= 2
  579. prev_dim = embed_dims[i]
  580. self.feature_info += [dict(num_chs=prev_dim, reduction=stride, module=f'stages.{i}')]
  581. stages.append(stage)
  582. self.stages = nn.Sequential(*stages)
  583. # Classifier head
  584. self.num_features = self.head_hidden_size = embed_dims[-1]
  585. self.norm = norm_layer(embed_dims[-1], **dd)
  586. self.head_drop = nn.Dropout(drop_rate)
  587. self.head = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity()
  588. self.dist = distillation
  589. if self.dist:
  590. self.head_dist = nn.Linear(embed_dims[-1], num_classes, **dd) if num_classes > 0 else nn.Identity()
  591. else:
  592. self.head_dist = None
  593. self.apply(self.init_weights)
  594. self.distilled_training = False
  595. # init for classification
  596. def init_weights(self, m):
  597. if isinstance(m, nn.Linear):
  598. trunc_normal_(m.weight, std=.02)
  599. if m.bias is not None:
  600. nn.init.constant_(m.bias, 0)
  601. @torch.jit.ignore
  602. def no_weight_decay(self):
  603. return {k for k, _ in self.named_parameters() if 'attention_biases' in k}
  604. @torch.jit.ignore
  605. def group_matcher(self, coarse=False):
  606. matcher = dict(
  607. stem=r'^stem', # stem and embed
  608. blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
  609. )
  610. return matcher
  611. @torch.jit.ignore
  612. def set_grad_checkpointing(self, enable=True):
  613. for s in self.stages:
  614. s.grad_checkpointing = enable
  615. @torch.jit.ignore
  616. def get_classifier(self) -> nn.Module:
  617. return self.head, self.head_dist
  618. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  619. self.num_classes = num_classes
  620. if global_pool is not None:
  621. self.global_pool = global_pool
  622. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  623. self.head_dist = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  624. @torch.jit.ignore
  625. def set_distilled_training(self, enable=True):
  626. self.distilled_training = enable
  627. def forward_intermediates(
  628. self,
  629. x: torch.Tensor,
  630. indices: Optional[Union[int, List[int]]] = None,
  631. norm: bool = False,
  632. stop_early: bool = False,
  633. output_fmt: str = 'NCHW',
  634. intermediates_only: bool = False,
  635. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  636. """ Forward features that returns intermediates.
  637. Args:
  638. x: Input image tensor
  639. indices: Take last n blocks if int, all if None, select matching indices if sequence
  640. norm: Apply norm layer to compatible intermediates
  641. stop_early: Stop iterating over blocks when last desired intermediate hit
  642. output_fmt: Shape of intermediate feature outputs
  643. intermediates_only: Only return intermediate features
  644. Returns:
  645. """
  646. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  647. intermediates = []
  648. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  649. # forward pass
  650. x = self.stem(x)
  651. last_idx = len(self.stages) - 1
  652. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  653. stages = self.stages
  654. else:
  655. stages = self.stages[:max_index + 1]
  656. for feat_idx, stage in enumerate(stages):
  657. x = stage(x)
  658. if feat_idx in take_indices:
  659. if feat_idx == last_idx:
  660. x_inter = self.norm(x) if norm else x
  661. intermediates.append(x_inter)
  662. else:
  663. intermediates.append(x)
  664. if intermediates_only:
  665. return intermediates
  666. if feat_idx == last_idx:
  667. x = self.norm(x)
  668. return x, intermediates
  669. def prune_intermediate_layers(
  670. self,
  671. indices: Union[int, List[int]] = 1,
  672. prune_norm: bool = False,
  673. prune_head: bool = True,
  674. ):
  675. """ Prune layers not required for specified intermediates.
  676. """
  677. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  678. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  679. if prune_norm:
  680. self.norm = nn.Identity()
  681. if prune_head:
  682. self.reset_classifier(0, '')
  683. return take_indices
  684. def forward_features(self, x):
  685. x = self.stem(x)
  686. x = self.stages(x)
  687. x = self.norm(x)
  688. return x
  689. def forward_head(self, x, pre_logits: bool = False):
  690. if self.global_pool == 'avg':
  691. x = x.mean(dim=(2, 3))
  692. x = self.head_drop(x)
  693. if pre_logits:
  694. return x
  695. x, x_dist = self.head(x), self.head_dist(x)
  696. if self.distilled_training and self.training and not torch.jit.is_scripting():
  697. # only return separate classification predictions when training in distilled mode
  698. return x, x_dist
  699. else:
  700. # during standard train/finetune, inference average the classifier predictions
  701. return (x + x_dist) / 2
  702. def forward(self, x):
  703. x = self.forward_features(x)
  704. x = self.forward_head(x)
  705. return x
  706. def _cfg(url='', **kwargs):
  707. return {
  708. 'url': url,
  709. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'fixed_input_size': True,
  710. 'crop_pct': .95, 'interpolation': 'bicubic',
  711. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  712. 'classifier': ('head', 'head_dist'), 'first_conv': 'stem.conv1.conv',
  713. 'license': 'apache-2.0',
  714. **kwargs
  715. }
  716. default_cfgs = generate_default_cfgs({
  717. 'efficientformerv2_s0.snap_dist_in1k': _cfg(
  718. hf_hub_id='timm/',
  719. ),
  720. 'efficientformerv2_s1.snap_dist_in1k': _cfg(
  721. hf_hub_id='timm/',
  722. ),
  723. 'efficientformerv2_s2.snap_dist_in1k': _cfg(
  724. hf_hub_id='timm/',
  725. ),
  726. 'efficientformerv2_l.snap_dist_in1k': _cfg(
  727. hf_hub_id='timm/',
  728. ),
  729. })
  730. def _create_efficientformerv2(variant, pretrained=False, **kwargs):
  731. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  732. model = build_model_with_cfg(
  733. EfficientFormerV2, variant, pretrained,
  734. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  735. **kwargs)
  736. return model
  737. @register_model
  738. def efficientformerv2_s0(pretrained=False, **kwargs) -> EfficientFormerV2:
  739. model_args = dict(
  740. depths=EfficientFormer_depth['S0'],
  741. embed_dims=EfficientFormer_width['S0'],
  742. num_vit=2,
  743. drop_path_rate=0.0,
  744. mlp_ratios=EfficientFormer_expansion_ratios['S0'],
  745. )
  746. return _create_efficientformerv2('efficientformerv2_s0', pretrained=pretrained, **dict(model_args, **kwargs))
  747. @register_model
  748. def efficientformerv2_s1(pretrained=False, **kwargs) -> EfficientFormerV2:
  749. model_args = dict(
  750. depths=EfficientFormer_depth['S1'],
  751. embed_dims=EfficientFormer_width['S1'],
  752. num_vit=2,
  753. drop_path_rate=0.0,
  754. mlp_ratios=EfficientFormer_expansion_ratios['S1'],
  755. )
  756. return _create_efficientformerv2('efficientformerv2_s1', pretrained=pretrained, **dict(model_args, **kwargs))
  757. @register_model
  758. def efficientformerv2_s2(pretrained=False, **kwargs) -> EfficientFormerV2:
  759. model_args = dict(
  760. depths=EfficientFormer_depth['S2'],
  761. embed_dims=EfficientFormer_width['S2'],
  762. num_vit=4,
  763. drop_path_rate=0.02,
  764. mlp_ratios=EfficientFormer_expansion_ratios['S2'],
  765. )
  766. return _create_efficientformerv2('efficientformerv2_s2', pretrained=pretrained, **dict(model_args, **kwargs))
  767. @register_model
  768. def efficientformerv2_l(pretrained=False, **kwargs) -> EfficientFormerV2:
  769. model_args = dict(
  770. depths=EfficientFormer_depth['L'],
  771. embed_dims=EfficientFormer_width['L'],
  772. num_vit=6,
  773. drop_path_rate=0.1,
  774. mlp_ratios=EfficientFormer_expansion_ratios['L'],
  775. )
  776. return _create_efficientformerv2('efficientformerv2_l', pretrained=pretrained, **dict(model_args, **kwargs))