edgenext.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705
  1. """ EdgeNeXt
  2. Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications`
  3. - https://arxiv.org/abs/2206.10589
  4. Original code and weights from https://github.com/mmaaz60/EdgeNeXt
  5. Modifications and additions for timm by / Copyright 2022, Ross Wightman
  6. """
  7. import math
  8. from functools import partial
  9. from typing import List, Optional, Tuple, Type, Union
  10. import torch
  11. import torch.nn.functional as F
  12. from torch import nn
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import (
  15. DropPath,
  16. calculate_drop_path_rates,
  17. LayerNorm2d,
  18. Mlp,
  19. create_conv2d,
  20. NormMlpClassifierHead,
  21. ClassifierHead,
  22. trunc_normal_tf_,
  23. )
  24. from ._builder import build_model_with_cfg
  25. from ._features import feature_take_indices
  26. from ._features_fx import register_notrace_module
  27. from ._manipulate import named_apply, checkpoint_seq
  28. from ._registry import register_model, generate_default_cfgs
# Public API of this module.
__all__ = ['EdgeNeXt']  # model_registry will add each entrypoint fn to this
  30. @register_notrace_module # reason: FX can't symbolically trace torch.arange in forward method
  31. class PositionalEncodingFourier(nn.Module):
  32. def __init__(
  33. self,
  34. hidden_dim: int = 32,
  35. dim: int = 768,
  36. temperature: float = 10000.,
  37. device=None,
  38. dtype=None,
  39. ):
  40. dd = {'device': device, 'dtype': dtype}
  41. super().__init__()
  42. self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1, **dd)
  43. self.scale = 2 * math.pi
  44. self.temperature = temperature
  45. self.hidden_dim = hidden_dim
  46. self.dim = dim
  47. def forward(self, shape: Tuple[int, int, int]):
  48. device = self.token_projection.weight.device
  49. dtype = self.token_projection.weight.dtype
  50. inv_mask = ~torch.zeros(shape).to(device=device, dtype=torch.bool)
  51. y_embed = inv_mask.cumsum(1, dtype=torch.float32)
  52. x_embed = inv_mask.cumsum(2, dtype=torch.float32)
  53. eps = 1e-6
  54. y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
  55. x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
  56. dim_t = torch.arange(self.hidden_dim, dtype=torch.int64, device=device).to(torch.float32)
  57. dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
  58. pos_x = x_embed[:, :, :, None] / dim_t
  59. pos_y = y_embed[:, :, :, None] / dim_t
  60. pos_x = torch.stack(
  61. (pos_x[:, :, :, 0::2].sin(),
  62. pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  63. pos_y = torch.stack(
  64. (pos_y[:, :, :, 0::2].sin(),
  65. pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  66. pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  67. pos = self.token_projection(pos.to(dtype))
  68. return pos
  69. class ConvBlock(nn.Module):
  70. def __init__(
  71. self,
  72. dim: int,
  73. dim_out: Optional[int] = None,
  74. kernel_size: int = 7,
  75. stride: int = 1,
  76. conv_bias: bool = True,
  77. expand_ratio: float = 4,
  78. ls_init_value: float = 1e-6,
  79. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  80. act_layer: Type[nn.Module] = nn.GELU,
  81. drop_path: float = 0.,
  82. device=None,
  83. dtype=None,
  84. ):
  85. dd = {'device': device, 'dtype': dtype}
  86. super().__init__()
  87. dim_out = dim_out or dim
  88. self.shortcut_after_dw = stride > 1 or dim != dim_out
  89. self.conv_dw = create_conv2d(
  90. dim,
  91. dim_out,
  92. kernel_size=kernel_size,
  93. stride=stride,
  94. depthwise=True,
  95. bias=conv_bias,
  96. **dd,
  97. )
  98. self.norm = norm_layer(dim_out, **dd)
  99. self.mlp = Mlp(
  100. dim_out,
  101. int(expand_ratio * dim_out),
  102. act_layer=act_layer,
  103. **dd,
  104. )
  105. self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out, **dd)) if ls_init_value > 0 else None
  106. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  107. def forward(self, x):
  108. shortcut = x
  109. x = self.conv_dw(x)
  110. if self.shortcut_after_dw:
  111. shortcut = x
  112. x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
  113. x = self.norm(x)
  114. x = self.mlp(x)
  115. if self.gamma is not None:
  116. x = self.gamma * x
  117. x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
  118. x = shortcut + self.drop_path(x)
  119. return x
  120. class CrossCovarianceAttn(nn.Module):
  121. def __init__(
  122. self,
  123. dim: int,
  124. num_heads: int = 8,
  125. qkv_bias: bool = False,
  126. attn_drop: float = 0.,
  127. proj_drop: float = 0.,
  128. device=None,
  129. dtype=None,
  130. ):
  131. dd = {'device': device, 'dtype': dtype}
  132. super().__init__()
  133. self.num_heads = num_heads
  134. self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1, **dd))
  135. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  136. self.attn_drop = nn.Dropout(attn_drop)
  137. self.proj = nn.Linear(dim, dim, **dd)
  138. self.proj_drop = nn.Dropout(proj_drop)
  139. def forward(self, x):
  140. B, N, C = x.shape
  141. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1)
  142. q, k, v = qkv.unbind(0)
  143. # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
  144. attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature
  145. attn = attn.softmax(dim=-1)
  146. attn = self.attn_drop(attn)
  147. x = (attn @ v)
  148. x = x.permute(0, 3, 1, 2).reshape(B, N, C)
  149. x = self.proj(x)
  150. x = self.proj_drop(x)
  151. return x
  152. @torch.jit.ignore
  153. def no_weight_decay(self):
  154. return {'temperature'}
  155. class SplitTransposeBlock(nn.Module):
  156. def __init__(
  157. self,
  158. dim: int,
  159. num_scales: int = 1,
  160. num_heads: int = 8,
  161. expand_ratio: float = 4,
  162. use_pos_emb: bool = True,
  163. conv_bias: bool = True,
  164. qkv_bias: bool = True,
  165. ls_init_value: float = 1e-6,
  166. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  167. act_layer: Type[nn.Module] = nn.GELU,
  168. drop_path: float = 0.,
  169. attn_drop: float = 0.,
  170. proj_drop: float = 0.,
  171. device=None,
  172. dtype=None,
  173. ):
  174. dd = {'device': device, 'dtype': dtype}
  175. super().__init__()
  176. width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales)))
  177. self.width = width
  178. self.num_scales = max(1, num_scales - 1)
  179. convs = []
  180. for i in range(self.num_scales):
  181. convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias, **dd))
  182. self.convs = nn.ModuleList(convs)
  183. self.pos_embd = None
  184. if use_pos_emb:
  185. self.pos_embd = PositionalEncodingFourier(dim=dim, **dd)
  186. self.norm_xca = norm_layer(dim, **dd)
  187. self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim, **dd)) if ls_init_value > 0 else None
  188. self.xca = CrossCovarianceAttn(
  189. dim,
  190. num_heads=num_heads,
  191. qkv_bias=qkv_bias,
  192. attn_drop=attn_drop,
  193. proj_drop=proj_drop,
  194. **dd,
  195. )
  196. self.norm = norm_layer(dim, eps=1e-6, **dd)
  197. self.mlp = Mlp(
  198. dim,
  199. int(expand_ratio * dim),
  200. act_layer=act_layer,
  201. **dd,
  202. )
  203. self.gamma = nn.Parameter(ls_init_value * torch.ones(dim, **dd)) if ls_init_value > 0 else None
  204. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  205. def forward(self, x):
  206. shortcut = x
  207. # scales code re-written for torchscript as per my res2net fixes -rw
  208. # NOTE torch.split(x, self.width, 1) causing issues with ONNX export
  209. spx = x.chunk(len(self.convs) + 1, dim=1)
  210. spo = []
  211. sp = spx[0]
  212. for i, conv in enumerate(self.convs):
  213. if i > 0:
  214. sp = sp + spx[i]
  215. sp = conv(sp)
  216. spo.append(sp)
  217. spo.append(spx[-1])
  218. x = torch.cat(spo, 1)
  219. # XCA
  220. B, C, H, W = x.shape
  221. x = x.reshape(B, C, H * W).permute(0, 2, 1)
  222. if self.pos_embd is not None:
  223. pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
  224. x = x + pos_encoding
  225. x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x)))
  226. x = x.reshape(B, H, W, C)
  227. # Inverted Bottleneck
  228. x = self.norm(x)
  229. x = self.mlp(x)
  230. if self.gamma is not None:
  231. x = self.gamma * x
  232. x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
  233. x = shortcut + self.drop_path(x)
  234. return x
  235. class EdgeNeXtStage(nn.Module):
  236. def __init__(
  237. self,
  238. in_chs: int,
  239. out_chs: int,
  240. stride: int = 2,
  241. depth: int = 2,
  242. num_global_blocks: int = 1,
  243. num_heads: int = 4,
  244. scales: int = 2,
  245. kernel_size: int = 7,
  246. expand_ratio: float = 4,
  247. use_pos_emb: bool = False,
  248. downsample_block: bool = False,
  249. conv_bias: float = True,
  250. ls_init_value: float = 1.0,
  251. drop_path_rates: Optional[List[float]] = None,
  252. norm_layer: Type[nn.Module] = LayerNorm2d,
  253. norm_layer_cl: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  254. act_layer: Type[nn.Module] = nn.GELU,
  255. device=None,
  256. dtype=None,
  257. ):
  258. dd = {'device': device, 'dtype': dtype}
  259. super().__init__()
  260. self.grad_checkpointing = False
  261. if downsample_block or stride == 1:
  262. self.downsample = nn.Identity()
  263. else:
  264. self.downsample = nn.Sequential(
  265. norm_layer(in_chs, **dd),
  266. nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias, **dd)
  267. )
  268. in_chs = out_chs
  269. stage_blocks = []
  270. for i in range(depth):
  271. if i < depth - num_global_blocks:
  272. stage_blocks.append(
  273. ConvBlock(
  274. dim=in_chs,
  275. dim_out=out_chs,
  276. stride=stride if downsample_block and i == 0 else 1,
  277. conv_bias=conv_bias,
  278. kernel_size=kernel_size,
  279. expand_ratio=expand_ratio,
  280. ls_init_value=ls_init_value,
  281. drop_path=drop_path_rates[i],
  282. norm_layer=norm_layer_cl,
  283. act_layer=act_layer,
  284. **dd,
  285. )
  286. )
  287. else:
  288. stage_blocks.append(
  289. SplitTransposeBlock(
  290. dim=in_chs,
  291. num_scales=scales,
  292. num_heads=num_heads,
  293. expand_ratio=expand_ratio,
  294. use_pos_emb=use_pos_emb,
  295. conv_bias=conv_bias,
  296. ls_init_value=ls_init_value,
  297. drop_path=drop_path_rates[i],
  298. norm_layer=norm_layer_cl,
  299. act_layer=act_layer,
  300. **dd,
  301. )
  302. )
  303. in_chs = out_chs
  304. self.blocks = nn.Sequential(*stage_blocks)
  305. def forward(self, x):
  306. x = self.downsample(x)
  307. if self.grad_checkpointing and not torch.jit.is_scripting():
  308. x = checkpoint_seq(self.blocks, x)
  309. else:
  310. x = self.blocks(x)
  311. return x
  312. class EdgeNeXt(nn.Module):
  313. def __init__(
  314. self,
  315. in_chans: int = 3,
  316. num_classes: int = 1000,
  317. global_pool: str = 'avg',
  318. dims: Tuple[int, ...] = (24, 48, 88, 168),
  319. depths: Tuple[int, ...] = (3, 3, 9, 3),
  320. global_block_counts: Tuple[int, ...] = (0, 1, 1, 1),
  321. kernel_sizes: Tuple[int, ...] = (3, 5, 7, 9),
  322. heads: Tuple[int, ...] = (8, 8, 8, 8),
  323. d2_scales: Tuple[int, ...] = (2, 2, 3, 4),
  324. use_pos_emb: Tuple[bool, ...] = (False, True, False, False),
  325. ls_init_value: float = 1e-6,
  326. head_init_scale: float = 1.,
  327. expand_ratio: float = 4,
  328. downsample_block: bool = False,
  329. conv_bias: bool = True,
  330. stem_type: str = 'patch',
  331. head_norm_first: bool = False,
  332. act_layer: Type[nn.Module] = nn.GELU,
  333. drop_path_rate: float = 0.,
  334. drop_rate: float = 0.,
  335. device=None,
  336. dtype=None,
  337. ):
  338. super().__init__()
  339. dd = {'device': device, 'dtype': dtype}
  340. self.num_classes = num_classes
  341. self.global_pool = global_pool
  342. self.drop_rate = drop_rate
  343. norm_layer = partial(LayerNorm2d, eps=1e-6)
  344. norm_layer_cl = partial(nn.LayerNorm, eps=1e-6)
  345. self.feature_info = []
  346. assert stem_type in ('patch', 'overlap')
  347. if stem_type == 'patch':
  348. self.stem = nn.Sequential(
  349. nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias, **dd,),
  350. norm_layer(dims[0], **dd),
  351. )
  352. else:
  353. self.stem = nn.Sequential(
  354. nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias, **dd),
  355. norm_layer(dims[0], **dd),
  356. )
  357. curr_stride = 4
  358. stages = []
  359. dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
  360. in_chs = dims[0]
  361. for i in range(4):
  362. stride = 2 if curr_stride == 2 or i > 0 else 1
  363. # FIXME support dilation / output_stride
  364. curr_stride *= stride
  365. stages.append(EdgeNeXtStage(
  366. in_chs=in_chs,
  367. out_chs=dims[i],
  368. stride=stride,
  369. depth=depths[i],
  370. num_global_blocks=global_block_counts[i],
  371. num_heads=heads[i],
  372. drop_path_rates=dp_rates[i],
  373. scales=d2_scales[i],
  374. expand_ratio=expand_ratio,
  375. kernel_size=kernel_sizes[i],
  376. use_pos_emb=use_pos_emb[i],
  377. ls_init_value=ls_init_value,
  378. downsample_block=downsample_block,
  379. conv_bias=conv_bias,
  380. norm_layer=norm_layer,
  381. norm_layer_cl=norm_layer_cl,
  382. act_layer=act_layer,
  383. **dd,
  384. ))
  385. # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
  386. in_chs = dims[i]
  387. self.feature_info += [dict(num_chs=in_chs, reduction=curr_stride, module=f'stages.{i}')]
  388. self.stages = nn.Sequential(*stages)
  389. self.num_features = self.head_hidden_size = dims[-1]
  390. if head_norm_first:
  391. self.norm_pre = norm_layer(self.num_features, **dd)
  392. self.head = ClassifierHead(
  393. self.num_features,
  394. num_classes,
  395. pool_type=global_pool,
  396. drop_rate=self.drop_rate,
  397. **dd,
  398. )
  399. else:
  400. self.norm_pre = nn.Identity()
  401. self.head = NormMlpClassifierHead(
  402. self.num_features,
  403. num_classes,
  404. pool_type=global_pool,
  405. drop_rate=self.drop_rate,
  406. norm_layer=norm_layer,
  407. **dd,
  408. )
  409. named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
  410. @torch.jit.ignore
  411. def group_matcher(self, coarse=False):
  412. return dict(
  413. stem=r'^stem',
  414. blocks=r'^stages\.(\d+)' if coarse else [
  415. (r'^stages\.(\d+)\.downsample', (0,)), # blocks
  416. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  417. (r'^norm_pre', (99999,))
  418. ]
  419. )
  420. @torch.jit.ignore
  421. def set_grad_checkpointing(self, enable=True):
  422. for s in self.stages:
  423. s.grad_checkpointing = enable
  424. @torch.jit.ignore
  425. def get_classifier(self) -> nn.Module:
  426. return self.head.fc
  427. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  428. self.num_classes = num_classes
  429. self.head.reset(num_classes, global_pool)
  430. def forward_intermediates(
  431. self,
  432. x: torch.Tensor,
  433. indices: Optional[Union[int, List[int]]] = None,
  434. norm: bool = False,
  435. stop_early: bool = False,
  436. output_fmt: str = 'NCHW',
  437. intermediates_only: bool = False,
  438. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  439. """ Forward features that returns intermediates.
  440. Args:
  441. x: Input image tensor
  442. indices: Take last n blocks if int, all if None, select matching indices if sequence
  443. norm: Apply norm layer to compatible intermediates
  444. stop_early: Stop iterating over blocks when last desired intermediate hit
  445. output_fmt: Shape of intermediate feature outputs
  446. intermediates_only: Only return intermediate features
  447. Returns:
  448. """
  449. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  450. intermediates = []
  451. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  452. # forward pass
  453. x = self.stem(x)
  454. last_idx = len(self.stages) - 1
  455. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  456. stages = self.stages
  457. else:
  458. stages = self.stages[:max_index + 1]
  459. for feat_idx, stage in enumerate(stages):
  460. x = stage(x)
  461. if feat_idx in take_indices:
  462. if norm and feat_idx == last_idx:
  463. x_inter = self.norm_pre(x) # applying final norm to last intermediate
  464. else:
  465. x_inter = x
  466. intermediates.append(x_inter)
  467. if intermediates_only:
  468. return intermediates
  469. if feat_idx == last_idx:
  470. x = self.norm_pre(x)
  471. return x, intermediates
  472. def prune_intermediate_layers(
  473. self,
  474. indices: Union[int, List[int]] = 1,
  475. prune_norm: bool = False,
  476. prune_head: bool = True,
  477. ):
  478. """ Prune layers not required for specified intermediates.
  479. """
  480. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  481. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  482. if prune_norm:
  483. self.norm_pre = nn.Identity()
  484. if prune_head:
  485. self.reset_classifier(0, '')
  486. return take_indices
  487. def forward_features(self, x):
  488. x = self.stem(x)
  489. x = self.stages(x)
  490. x = self.norm_pre(x)
  491. return x
  492. def forward_head(self, x, pre_logits: bool = False):
  493. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  494. def forward(self, x):
  495. x = self.forward_features(x)
  496. x = self.forward_head(x)
  497. return x
  498. def _init_weights(module, name=None, head_init_scale=1.0):
  499. if isinstance(module, nn.Conv2d):
  500. trunc_normal_tf_(module.weight, std=.02)
  501. if module.bias is not None:
  502. nn.init.zeros_(module.bias)
  503. elif isinstance(module, nn.Linear):
  504. trunc_normal_tf_(module.weight, std=.02)
  505. nn.init.zeros_(module.bias)
  506. if name and 'head.' in name:
  507. module.weight.data.mul_(head_init_scale)
  508. module.bias.data.mul_(head_init_scale)
  509. def checkpoint_filter_fn(state_dict, model):
  510. """ Remap FB checkpoints -> timm """
  511. if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
  512. return state_dict # non-FB checkpoint
  513. # models were released as train checkpoints... :/
  514. if 'model_ema' in state_dict:
  515. state_dict = state_dict['model_ema']
  516. elif 'model' in state_dict:
  517. state_dict = state_dict['model']
  518. elif 'state_dict' in state_dict:
  519. state_dict = state_dict['state_dict']
  520. out_dict = {}
  521. import re
  522. for k, v in state_dict.items():
  523. k = k.replace('downsample_layers.0.', 'stem.')
  524. k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
  525. k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
  526. k = k.replace('dwconv', 'conv_dw')
  527. k = k.replace('pwconv', 'mlp.fc')
  528. k = k.replace('head.', 'head.fc.')
  529. if k.startswith('norm.'):
  530. k = k.replace('norm', 'head.norm')
  531. if v.ndim == 2 and 'head' not in k:
  532. model_shape = model.state_dict()[k].shape
  533. v = v.reshape(model_shape)
  534. out_dict[k] = v
  535. return out_dict
  536. def _create_edgenext(variant, pretrained=False, **kwargs):
  537. model = build_model_with_cfg(
  538. EdgeNeXt, variant, pretrained,
  539. pretrained_filter_fn=checkpoint_filter_fn,
  540. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  541. **kwargs)
  542. return model
  543. def _cfg(url='', **kwargs):
  544. return {
  545. 'url': url,
  546. 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
  547. 'crop_pct': 0.9, 'interpolation': 'bicubic',
  548. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  549. 'first_conv': 'stem.0', 'classifier': 'head.fc',
  550. 'license': 'mit',
  551. **kwargs
  552. }
# Pretrained weight configs, keyed '<model_name>.<pretrained_tag>'. Each entry starts
# from `_cfg` defaults with per-weight overrides (hub location, crop / test-time size).
default_cfgs = generate_default_cfgs({
    'edgenext_xx_small.in1k': _cfg(
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'edgenext_x_small.in1k': _cfg(
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'edgenext_small.usi_in1k': _cfg(  # USI weights
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
    ),
    'edgenext_base.usi_in1k': _cfg(  # USI weights
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
    ),
    # NOTE(review): the tag suggests ImageNet-21k pretraining fine-tuned on ImageNet-1k.
    # This entry was previously annotated "USI weights", which looks copy-pasted from the
    # entry above. Confirm the weight provenance.
    'edgenext_base.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
    ),
    'edgenext_small_rw.sw_in1k': _cfg(
        hf_hub_id='timm/',
        test_input_size=(3, 320, 320), test_crop_pct=1.0,
    ),
})
  577. @register_model
  578. def edgenext_xx_small(pretrained=False, **kwargs) -> EdgeNeXt:
  579. # 1.33M & 260.58M @ 256 resolution
  580. # 71.23% Top-1 accuracy
  581. # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
  582. # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS
  583. # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS
  584. model_args = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4))
  585. return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **dict(model_args, **kwargs))
  586. @register_model
  587. def edgenext_x_small(pretrained=False, **kwargs) -> EdgeNeXt:
  588. # 2.34M & 538.0M @ 256 resolution
  589. # 75.00% Top-1 accuracy
  590. # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
  591. # Jetson FPS=31.61 versus 28.49 for MobileViT_XS
  592. # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS
  593. model_args = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4))
  594. return _create_edgenext('edgenext_x_small', pretrained=pretrained, **dict(model_args, **kwargs))
  595. @register_model
  596. def edgenext_small(pretrained=False, **kwargs) -> EdgeNeXt:
  597. # 5.59M & 1260.59M @ 256 resolution
  598. # 79.43% Top-1 accuracy
  599. # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
  600. # Jetson FPS=20.47 versus 18.86 for MobileViT_S
  601. # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
  602. model_args = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304))
  603. return _create_edgenext('edgenext_small', pretrained=pretrained, **dict(model_args, **kwargs))
  604. @register_model
  605. def edgenext_base(pretrained=False, **kwargs) -> EdgeNeXt:
  606. # 18.51M & 3840.93M @ 256 resolution
  607. # 82.5% (normal) 83.7% (USI) Top-1 accuracy
  608. # AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
  609. # Jetson FPS=xx.xx versus xx.xx for MobileViT_S
  610. # For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx
  611. model_args = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584])
  612. return _create_edgenext('edgenext_base', pretrained=pretrained, **dict(model_args, **kwargs))
  613. @register_model
  614. def edgenext_small_rw(pretrained=False, **kwargs) -> EdgeNeXt:
  615. model_args = dict(
  616. depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
  617. downsample_block=True, conv_bias=False, stem_type='overlap')
  618. return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **dict(model_args, **kwargs))