repvit.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692
  1. """ RepViT
  2. Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
  3. - https://arxiv.org/abs/2307.09283
  4. @misc{wang2023repvit,
  5. title={RepViT: Revisiting Mobile CNN From ViT Perspective},
  6. author={Ao Wang and Hui Chen and Zijia Lin and Hengjun Pu and Guiguang Ding},
  7. year={2023},
  8. eprint={2307.09283},
  9. archivePrefix={arXiv},
  10. primaryClass={cs.CV}
  11. }
  12. Adapted from official impl at https://github.com/jameslahm/RepViT
  13. """
  14. from typing import List, Optional, Tuple, Union, Type
  15. import torch
  16. import torch.nn as nn
  17. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  18. from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
  19. from ._builder import build_model_with_cfg
  20. from ._features import feature_take_indices
  21. from ._manipulate import checkpoint, checkpoint_seq
  22. from ._registry import register_model, generate_default_cfgs
  # Public API of this module (kept as a mutable list).
  __all__ = ['RepVit']
  class ConvNorm(nn.Sequential):
      """Bias-free Conv2d followed by BatchNorm2d, foldable into one biased Conv2d.

      Args:
          in_dim: Input channels.
          out_dim: Output channels.
          ks: Convolution kernel size.
          stride: Convolution stride.
          pad: Convolution padding.
          dilation: Convolution dilation.
          groups: Convolution groups.
          bn_weight_init: Initial value of the BatchNorm scale (its bias starts at 0).
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              in_dim: int,
              out_dim: int,
              ks: int = 1,
              stride: int = 1,
              pad: int = 0,
              dilation: int = 1,
              groups: int = 1,
              bn_weight_init: float = 1,
              device=None,
              dtype=None,
      ):
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          self.add_module('c', nn.Conv2d(in_dim, out_dim, ks, stride, pad, dilation, groups, bias=False, **dd))
          self.add_module('bn', nn.BatchNorm2d(out_dim, **dd))
          nn.init.constant_(self.bn.weight, bn_weight_init)
          nn.init.constant_(self.bn.bias, 0)

      @torch.no_grad()
      def fuse(self):
          """Fold the BatchNorm running statistics into a new biased Conv2d.

          The result matches this module in eval mode. It is created on the same
          device and with the same dtype as the existing conv weight, so a
          half/bfloat16/float64 model stays in its dtype after fusion.

          Returns:
              nn.Conv2d: The fused convolution.
          """
          c, bn = self._modules.values()
          # Per-output-channel multiplier applied by BN in eval mode.
          scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
          w = c.weight * scale[:, None, None, None]
          b = bn.bias - bn.running_mean * scale
          m = nn.Conv2d(
              w.size(1) * c.groups,
              w.size(0),
              w.shape[2:],
              stride=c.stride,
              padding=c.padding,
              dilation=c.dilation,
              groups=c.groups,
              device=c.weight.device,
              dtype=c.weight.dtype,
          )
          m.weight.data.copy_(w)
          m.bias.data.copy_(b)
          return m
  class NormLinear(nn.Sequential):
      """BatchNorm1d followed by Linear, foldable into a single Linear.

      Args:
          in_dim: Input features.
          out_dim: Output features.
          bias: Whether the Linear layer has a bias.
          std: Std of the truncated-normal init for the Linear weight.
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              in_dim: int,
              out_dim: int,
              bias: bool = True,
              std: float = 0.02,
              device=None,
              dtype=None,
      ):
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          self.add_module('bn', nn.BatchNorm1d(in_dim, **dd))
          self.add_module('l', nn.Linear(in_dim, out_dim, bias=bias, **dd))
          trunc_normal_(self.l.weight, std=std)
          if bias:
              nn.init.constant_(self.l.bias, 0)

      @torch.no_grad()
      def fuse(self):
          """Fold the BatchNorm running statistics into a new biased Linear.

          In eval mode, ``lin(bn(x)) = W (x * scale + shift) + b0``, which equals
          ``(W * scale) x + (W @ shift + b0)``. The new layer keeps the device and
          dtype of the existing weight.

          Returns:
              nn.Linear: The fused linear layer (always with a bias).
          """
          bn, lin = self._modules.values()
          scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
          shift = bn.bias - bn.running_mean * scale
          w = lin.weight * scale[None, :]
          b = lin.weight @ shift
          if lin.bias is not None:
              b = b + lin.bias
          m = nn.Linear(w.size(1), w.size(0), device=lin.weight.device, dtype=lin.weight.dtype)
          m.weight.data.copy_(w)
          m.bias.data.copy_(b)
          return m
  class RepVggDw(nn.Module):
      """Re-parameterizable depthwise token mixer.

      The output is the sum of three parallel branches over the same input: a
      ``kernel_size`` depthwise ConvNorm, a 1x1 depthwise conv, and the identity.
      ``fuse()`` collapses them into a single depthwise Conv2d.

      Args:
          ed: Number of channels (input and output are equal).
          kernel_size: Spatial size of the main depthwise kernel.
          legacy: If True, the 1x1 branch is a ``ConvNorm`` and no BatchNorm follows
              the branch sum. Otherwise the 1x1 branch is a biased ``nn.Conv2d`` and
              a ``BatchNorm2d`` is applied to the sum.
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              ed: int,
              kernel_size: int,
              legacy: bool = False,
              device=None,
              dtype=None,
      ):
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed, **dd)
          if legacy:
              self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed, **dd)
              # Make torchscript happy.
              self.bn = nn.Identity()
          else:
              self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed, **dd)
              self.bn = nn.BatchNorm2d(ed, **dd)
          self.dim = ed
          self.legacy = legacy

      def forward(self, x):
          return self.bn(self.conv(x) + self.conv1(x) + x)

      @torch.no_grad()
      def fuse(self):
          """Return one depthwise Conv2d equivalent (in eval mode) to this module.

          The 1x1 branch and the identity branch are zero-padded to the main
          kernel's spatial size and added to it. When not in legacy mode, the
          trailing BatchNorm is then folded in as well.
          """
          conv = self.conv.fuse()
          conv1 = self.conv1.fuse() if self.legacy else self.conv1
          conv_w = conv.weight
          # Padding that centres a 1x1 kernel inside the k x k main kernel. It is
          # derived from the kernel shape so that any odd kernel_size works.
          p = (conv_w.shape[-1] - 1) // 2
          conv1_w = nn.functional.pad(conv1.weight, [p, p, p, p])
          identity = nn.functional.pad(
              torch.ones(
                  conv1.weight.shape[0], conv1.weight.shape[1], 1, 1,
                  device=conv_w.device, dtype=conv_w.dtype,
              ),
              [p, p, p, p],
          )
          conv.weight.data.copy_(conv_w + conv1_w + identity)
          conv.bias.data.copy_(conv.bias + conv1.bias)
          if not self.legacy:
              # Fold the trailing BatchNorm into the merged conv.
              bn = self.bn
              scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
              w = conv.weight * scale[:, None, None, None]
              b = bn.bias + (conv.bias - bn.running_mean) * scale
              conv.weight.data.copy_(w)
              conv.bias.data.copy_(b)
          return conv
  class RepVitMlp(nn.Module):
      """Pointwise MLP: 1x1 ConvNorm expand -> activation -> 1x1 ConvNorm project.

      The projection's BatchNorm scale is initialised to zero (``bn_weight_init=0``).

      Args:
          in_dim: Input (and output) channels.
          hidden_dim: Hidden channels.
          act_layer: Activation layer class, instantiated with no arguments.
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              in_dim: int,
              hidden_dim: int,
              act_layer: Type[nn.Module],
              device=None,
              dtype=None,
      ):
          factory = dict(device=device, dtype=dtype)
          super().__init__()
          self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0, **factory)
          self.act = act_layer()
          self.conv2 = ConvNorm(hidden_dim, in_dim, 1, 1, 0, bn_weight_init=0, **factory)

      def forward(self, x):
          hidden = self.conv1(x)
          hidden = self.act(hidden)
          return self.conv2(hidden)
  class RepViTBlock(nn.Module):
      """RepViT block: RepVggDw token mixer, optional SqueezeExcite, then a residual MLP.

      Args:
          in_dim: Channels (input and output).
          mlp_ratio: Hidden-width multiplier for the channel-mixer MLP. It may be an
              int or a float; the hidden width is truncated to an int.
          kernel_size: Depthwise kernel size of the token mixer.
          use_se: If True, insert a ``SqueezeExcite`` layer after the token mixer.
          act_layer: Activation layer class used in the MLP.
          legacy: Forwarded to ``RepVggDw``.
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              in_dim: int,
              mlp_ratio: float,
              kernel_size: int,
              use_se: bool,
              act_layer: Type[nn.Module],
              legacy: bool = False,
              device=None,
              dtype=None,
      ):
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          self.token_mixer = RepVggDw(in_dim, kernel_size, legacy, **dd)
          self.se = SqueezeExcite(in_dim, 0.25, **dd) if use_se else nn.Identity()
          # Conv channel counts must be ints; a float mlp_ratio would otherwise
          # produce a float hidden width and break Conv2d construction.
          self.channel_mixer = RepVitMlp(in_dim, int(in_dim * mlp_ratio), act_layer, **dd)

      def forward(self, x):
          x = self.token_mixer(x)
          x = self.se(x)
          identity = x
          x = self.channel_mixer(x)
          return identity + x
  class RepVitStem(nn.Module):
      """Stem: two stride-2 3x3 ConvNorm layers with an activation in between.

      The first conv outputs ``out_chs // 2`` channels. The overall stride is
      recorded as ``self.stride`` (4).

      Args:
          in_chs: Input image channels.
          out_chs: Output channels.
          act_layer: Activation layer class.
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              in_chs: int,
              out_chs: int,
              act_layer: Type[nn.Module],
              device=None,
              dtype=None,
      ):
          factory = dict(device=device, dtype=dtype)
          super().__init__()
          mid_chs = out_chs // 2
          self.conv1 = ConvNorm(in_chs, mid_chs, 3, 2, 1, **factory)
          self.act1 = act_layer()
          self.conv2 = ConvNorm(mid_chs, out_chs, 3, 2, 1, **factory)
          self.stride = 4

      def forward(self, x):
          x = self.conv1(x)
          x = self.act1(x)
          return self.conv2(x)
  class RepVitDownsample(nn.Module):
      """Stage transition: pre-block, stride-2 depthwise conv, 1x1 channel projection, residual MLP.

      Args:
          in_dim: Input channels.
          mlp_ratio: Hidden-width multiplier for the MLPs. It may be an int or a
              float; hidden widths are truncated to ints.
          out_dim: Output channels.
          kernel_size: Depthwise kernel size (pre-block and spatial downsample).
          act_layer: Activation layer class.
          legacy: Forwarded to the pre-block's token mixer.
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              in_dim: int,
              mlp_ratio: float,
              out_dim: int,
              kernel_size: int,
              act_layer: Type[nn.Module],
              legacy: bool = False,
              device=None,
              dtype=None,
      ):
          dd = {'device': device, 'dtype': dtype}
          super().__init__()
          self.pre_block = RepViTBlock(
              in_dim,
              mlp_ratio,
              kernel_size,
              use_se=False,
              act_layer=act_layer,
              legacy=legacy,
              **dd,
          )
          self.spatial_downsample = ConvNorm(
              in_dim,
              in_dim,
              kernel_size,
              stride=2,
              pad=(kernel_size - 1) // 2,
              groups=in_dim,
              **dd,
          )
          self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1, **dd)
          # Conv channel counts must be ints, even when mlp_ratio is a float.
          self.ffn = RepVitMlp(out_dim, int(out_dim * mlp_ratio), act_layer, **dd)

      def forward(self, x):
          x = self.pre_block(x)
          x = self.spatial_downsample(x)
          x = self.channel_downsample(x)
          identity = x
          x = self.ffn(x)
          return x + identity
  class RepVitClassifier(nn.Module):
      """Classifier head with an optional distillation head.

      Each head is a ``NormLinear``, or ``nn.Identity`` when ``num_classes <= 0``.
      With distillation enabled, the two head outputs are averaged. The exception is
      training with ``distilled_training`` set (outside TorchScript), where the pair
      ``(head_out, dist_out)`` is returned instead.

      Args:
          dim: Input feature dimension.
          num_classes: Number of output classes.
          distillation: Whether to build the extra distillation head.
          drop: Dropout rate applied before the heads.
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              dim: int,
              num_classes: int,
              distillation: bool = False,
              drop: float = 0.0,
              device=None,
              dtype=None,
      ):
          dd = {'device': device, 'dtype': dtype}
          super().__init__()

          def _make_head():
              # A real classifier only when there are classes to predict.
              if num_classes > 0:
                  return NormLinear(dim, num_classes, **dd)
              return nn.Identity()

          self.head_drop = nn.Dropout(drop)
          self.head = _make_head()
          self.distillation = distillation
          self.distilled_training = False
          self.num_classes = num_classes
          if distillation:
              self.head_dist = _make_head()

      def forward(self, x):
          x = self.head_drop(x)
          if not self.distillation:
              return self.head(x)
          out, out_dist = self.head(x), self.head_dist(x)
          if self.training and self.distilled_training and not torch.jit.is_scripting():
              return out, out_dist
          return (out + out_dist) / 2

      @torch.no_grad()
      def fuse(self):
          """Return one Linear equivalent to this head in eval mode (Identity if there are no classes)."""
          if self.num_classes <= 0:
              return nn.Identity()
          head = self.head.fuse()
          if not self.distillation:
              return head
          # Average the two fused heads, matching the eval-mode forward.
          head_dist = self.head_dist.fuse()
          head.weight.add_(head_dist.weight).div_(2)
          head.bias.add_(head_dist.bias).div_(2)
          return head
  class RepVitStage(nn.Module):
      """One RepViT stage: an optional downsample followed by ``depth`` RepViT blocks.

      SqueezeExcite alternates across the blocks, starting enabled on the first
      (even-indexed) block.

      Args:
          in_dim: Input channels.
          out_dim: Output channels. Must equal ``in_dim`` when ``downsample`` is False.
          depth: Number of RepViT blocks.
          mlp_ratio: MLP hidden-width multiplier.
          act_layer: Activation layer class.
          kernel_size: Depthwise kernel size.
          downsample: Whether to start with a ``RepVitDownsample``.
          legacy: Forwarded to the blocks' token mixers.
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              in_dim: int,
              out_dim: int,
              depth: int,
              mlp_ratio: float,
              act_layer: Type[nn.Module],
              kernel_size: int = 3,
              downsample: bool = True,
              legacy: bool = False,
              device=None,
              dtype=None,
      ):
          factory = dict(device=device, dtype=dtype)
          super().__init__()
          if not downsample:
              assert in_dim == out_dim
              self.downsample = nn.Identity()
          else:
              self.downsample = RepVitDownsample(
                  in_dim,
                  mlp_ratio,
                  out_dim,
                  kernel_size,
                  act_layer=act_layer,
                  legacy=legacy,
                  **factory,
              )
          self.blocks = nn.Sequential(*[
              RepViTBlock(out_dim, mlp_ratio, kernel_size, idx % 2 == 0, act_layer, legacy, **factory)
              for idx in range(depth)
          ])

      def forward(self, x):
          return self.blocks(self.downsample(x))
  class RepVit(nn.Module):
      """RepViT network: stem -> sequence of RepVitStage -> pooled classifier head.

      Args:
          in_chans: Number of input image channels.
          img_size: Input image size. Accepted for API compatibility; it is not used
              to build the network.
          embed_dim: Channel width of each stage. Its length sets the number of stages.
          depth: Number of RepViT blocks in each stage.
          mlp_ratio: MLP expansion ratio, expanded to one value per stage via ``to_ntuple``.
          global_pool: ``'avg'`` averages over the spatial dims in ``forward_head``;
              any other value skips pooling.
          kernel_size: Depthwise kernel size used in all stages.
          num_classes: Number of classifier outputs. Values <= 0 use an identity head.
          act_layer: Activation layer class.
          distillation: Whether to build an extra distillation head.
          drop_rate: Dropout rate applied to the pooled features before the head.
          legacy: Use the legacy ``RepVggDw`` layout in every block.
          device: Device for parameter creation.
          dtype: Dtype for parameter creation.
      """

      def __init__(
              self,
              in_chans: int = 3,
              img_size: int = 224,
              embed_dim: Tuple[int, ...] = (48,),
              depth: Tuple[int, ...] = (2,),
              mlp_ratio: float = 2,
              global_pool: str = 'avg',
              kernel_size: int = 3,
              num_classes: int = 1000,
              act_layer: Type[nn.Module] = nn.GELU,
              distillation: bool = True,
              drop_rate: float = 0.0,
              legacy: bool = False,
              device=None,
              dtype=None,
      ):
          super().__init__()
          dd = {'device': device, 'dtype': dtype}
          self.grad_checkpointing = False
          self.global_pool = global_pool
          self.embed_dim = embed_dim
          self.num_classes = num_classes

          in_dim = embed_dim[0]
          self.stem = RepVitStem(in_chans, in_dim, act_layer, **dd)
          stride = self.stem.stride

          num_stages = len(embed_dim)
          mlp_ratios = to_ntuple(num_stages)(mlp_ratio)

          self.feature_info = []
          stages = []
          for i in range(num_stages):
              # Every stage except the first begins with a stride-2 downsample.
              downsample = i > 0
              stages.append(
                  RepVitStage(
                      in_dim,
                      embed_dim[i],
                      depth[i],
                      mlp_ratio=mlp_ratios[i],
                      act_layer=act_layer,
                      kernel_size=kernel_size,
                      downsample=downsample,
                      legacy=legacy,
                      **dd,
                  )
              )
              stride *= 2 if downsample else 1
              self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
              in_dim = embed_dim[i]
          self.stages = nn.Sequential(*stages)

          self.num_features = self.head_hidden_size = embed_dim[-1]
          self.head_drop = nn.Dropout(drop_rate)
          self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation, **dd)

      @torch.jit.ignore
      def group_matcher(self, coarse=False):
          """Return regexes that group parameter names for layer-wise treatment.

          Parameters live under ``stem.*``, ``stages.<i>.downsample.*`` and
          ``stages.<i>.blocks.<j>.*``; there is no top-level ``blocks`` or ``norm``.
          With ``coarse=True``, each stage is one group. Otherwise a stage's
          downsample shares ordinal ``(i, 0)`` and each block gets ``(i, j)``.
          """
          return dict(
              stem=r'^stem',
              blocks=r'^stages\.(\d+)' if coarse else [
                  (r'^stages\.(\d+)\.downsample', (0,)),
                  (r'^stages\.(\d+)\.blocks\.(\d+)', None),
              ],
          )

      @torch.jit.ignore
      def set_grad_checkpointing(self, enable=True):
          self.grad_checkpointing = enable

      @torch.jit.ignore
      def get_classifier(self) -> nn.Module:
          return self.head

      def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation: bool = False, device=None, dtype=None):
          """Replace the classifier head (and optionally the pooling mode)."""
          self.num_classes = num_classes
          if global_pool is not None:
              self.global_pool = global_pool
          dd = {'device': device, 'dtype': dtype}
          self.head = RepVitClassifier(self.embed_dim[-1], num_classes, distillation, **dd)

      @torch.jit.ignore
      def set_distilled_training(self, enable=True):
          self.head.distilled_training = enable

      def forward_intermediates(
              self,
              x: torch.Tensor,
              indices: Optional[Union[int, List[int]]] = None,
              norm: bool = False,
              stop_early: bool = False,
              output_fmt: str = 'NCHW',
              intermediates_only: bool = False,
      ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
          """ Forward features that returns intermediates.
          Args:
              x: Input image tensor
              indices: Take last n blocks if int, all if None, select matching indices if sequence
              norm: Apply norm layer to compatible intermediates (this model has no final
                  norm, so it has no effect)
              stop_early: Stop iterating over blocks when last desired intermediate hit
              output_fmt: Shape of intermediate feature outputs
              intermediates_only: Only return intermediate features
          Returns:
              If ``intermediates_only``, the list of selected stage outputs. Otherwise
              a tuple of (output of the last executed stage, list of selected stage outputs).
          """
          assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
          intermediates = []
          take_indices, max_index = feature_take_indices(len(self.stages), indices)

          # forward pass
          x = self.stem(x)
          if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
              stages = self.stages
          else:
              stages = self.stages[:max_index + 1]
          for feat_idx, stage in enumerate(stages):
              if self.grad_checkpointing and not torch.jit.is_scripting():
                  x = checkpoint(stage, x)
              else:
                  x = stage(x)
              if feat_idx in take_indices:
                  intermediates.append(x)

          if intermediates_only:
              return intermediates
          return x, intermediates

      def prune_intermediate_layers(
              self,
              indices: Union[int, List[int]] = 1,
              prune_norm: bool = False,
              prune_head: bool = True,
      ):
          """ Prune layers not required for specified intermediates.
          """
          take_indices, max_index = feature_take_indices(len(self.stages), indices)
          self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
          if prune_head:
              self.reset_classifier(0, '')
          return take_indices

      def forward_features(self, x):
          x = self.stem(x)
          if self.grad_checkpointing and not torch.jit.is_scripting():
              x = checkpoint_seq(self.stages, x)
          else:
              x = self.stages(x)
          return x

      def forward_head(self, x, pre_logits: bool = False):
          if self.global_pool == 'avg':
              x = x.mean((2, 3), keepdim=False)
          x = self.head_drop(x)
          if pre_logits:
              return x
          return self.head(x)

      def forward(self, x):
          x = self.forward_features(x)
          x = self.forward_head(x)
          return x

      @torch.no_grad()
      def fuse(self):
          """Recursively replace, in place, every submodule defining ``fuse()`` with its fused result."""
          def fuse_children(net):
              for child_name, child in net.named_children():
                  if hasattr(child, 'fuse'):
                      fused = child.fuse()
                      setattr(net, child_name, fused)
                      fuse_children(fused)
                  else:
                      fuse_children(child)
          fuse_children(self)
  def _cfg(url='', **kwargs):
      """Build the default pretrained-config dict shared by all RepViT variants.

      Any ``kwargs`` override matching defaults in place; new keys are appended.
      """
      cfg = dict(
          url=url,
          num_classes=1000,
          input_size=(3, 224, 224),
          pool_size=(7, 7),
          crop_pct=0.95,
          interpolation='bicubic',
          mean=IMAGENET_DEFAULT_MEAN,
          std=IMAGENET_DEFAULT_STD,
          first_conv='stem.conv1.c',
          classifier=('head.head.l', 'head.head_dist.l'),
          license='apache-2.0',
      )
      cfg.update(kwargs)
      return cfg
  # Every released tag uses the shared _cfg defaults with weights under the 'timm/' hub org.
  default_cfgs = generate_default_cfgs({
      tag: _cfg(hf_hub_id='timm/')
      for tag in (
          'repvit_m1.dist_in1k',
          'repvit_m2.dist_in1k',
          'repvit_m3.dist_in1k',
          'repvit_m0_9.dist_300e_in1k',
          'repvit_m0_9.dist_450e_in1k',
          'repvit_m1_0.dist_300e_in1k',
          'repvit_m1_0.dist_450e_in1k',
          'repvit_m1_1.dist_300e_in1k',
          'repvit_m1_1.dist_450e_in1k',
          'repvit_m1_5.dist_300e_in1k',
          'repvit_m1_5.dist_450e_in1k',
          'repvit_m2_3.dist_300e_in1k',
          'repvit_m2_3.dist_450e_in1k',
      )
  })
  def _create_repvit(variant, pretrained=False, **kwargs):
      """Build a ``RepVit`` for ``variant`` through ``build_model_with_cfg``.

      ``out_indices`` is popped from ``kwargs`` (default ``(0, 1, 2, 3)``) and placed
      in the feature config. The remaining kwargs are forwarded to ``build_model_with_cfg``.
      """
      feature_cfg = dict(
          flatten_sequential=True,
          out_indices=kwargs.pop('out_indices', (0, 1, 2, 3)),
      )
      return build_model_with_cfg(RepVit, variant, pretrained, feature_cfg=feature_cfg, **kwargs)
  @register_model
  def repvit_m1(pretrained=False, **kwargs):
      """Build RepViT-M1 (4 stages, legacy token-mixer layout); kwargs override the defaults."""
      model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True)
      model_args.update(kwargs)
      return _create_repvit('repvit_m1', pretrained=pretrained, **model_args)
  @register_model
  def repvit_m2(pretrained=False, **kwargs):
      """Build RepViT-M2 (4 stages, legacy token-mixer layout); kwargs override the defaults."""
      model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True)
      model_args.update(kwargs)
      return _create_repvit('repvit_m2', pretrained=pretrained, **model_args)
  @register_model
  def repvit_m3(pretrained=False, **kwargs):
      """Build RepViT-M3 (4 stages, legacy token-mixer layout); kwargs override the defaults."""
      model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True)
      model_args.update(kwargs)
      return _create_repvit('repvit_m3', pretrained=pretrained, **model_args)
  @register_model
  def repvit_m0_9(pretrained=False, **kwargs):
      """Build RepViT-M0.9 (4 stages); kwargs override the defaults."""
      model_args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
      model_args.update(kwargs)
      return _create_repvit('repvit_m0_9', pretrained=pretrained, **model_args)
  @register_model
  def repvit_m1_0(pretrained=False, **kwargs):
      """Build RepViT-M1.0 (4 stages); kwargs override the defaults."""
      model_args = dict(embed_dim=(56, 112, 224, 448), depth=(2, 2, 14, 2))
      model_args.update(kwargs)
      return _create_repvit('repvit_m1_0', pretrained=pretrained, **model_args)
  @register_model
  def repvit_m1_1(pretrained=False, **kwargs):
      """Build RepViT-M1.1 (4 stages); kwargs override the defaults."""
      model_args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
      model_args.update(kwargs)
      return _create_repvit('repvit_m1_1', pretrained=pretrained, **model_args)
  @register_model
  def repvit_m1_5(pretrained=False, **kwargs):
      """Build RepViT-M1.5 (4 stages); kwargs override the defaults."""
      model_args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 24, 4))
      model_args.update(kwargs)
      return _create_repvit('repvit_m1_5', pretrained=pretrained, **model_args)
  @register_model
  def repvit_m2_3(pretrained=False, **kwargs):
      """Build RepViT-M2.3 (4 stages); kwargs override the defaults."""
      model_args = dict(embed_dim=(80, 160, 320, 640), depth=(6, 6, 34, 2))
      model_args.update(kwargs)
      return _create_repvit('repvit_m2_3', pretrained=pretrained, **model_args)