efficientvit_mit.py 41 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279
  1. """ EfficientViT (by MIT Song Han's Lab)
  2. Paper: `Efficientvit: Enhanced linear attention for high-resolution low-computation visual recognition`
  3. - https://arxiv.org/abs/2205.14756
  4. Adapted from official impl at https://github.com/mit-han-lab/efficientvit
  5. """
  6. __all__ = ['EfficientVit', 'EfficientVitLarge']
  7. from typing import List, Optional, Tuple, Type, Union
  8. from functools import partial
  9. import torch
  10. import torch.nn as nn
  11. import torch.nn.functional as F
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  13. from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
  14. from ._builder import build_model_with_cfg
  15. from ._features import feature_take_indices
  16. from ._features_fx import register_notrace_module
  17. from ._manipulate import checkpoint_seq
  18. from ._registry import register_model, generate_default_cfgs
  19. def val2list(x: list or tuple or any, repeat_time=1):
  20. if isinstance(x, (list, tuple)):
  21. return list(x)
  22. return [x for _ in range(repeat_time)]
  23. def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1):
  24. # repeat elements if necessary
  25. x = val2list(x)
  26. if len(x) > 0:
  27. x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
  28. return tuple(x)
  29. def get_same_padding(kernel_size: int or tuple[int, ...]) -> int or tuple[int, ...]:
  30. if isinstance(kernel_size, tuple):
  31. return tuple([get_same_padding(ks) for ks in kernel_size])
  32. else:
  33. assert kernel_size % 2 > 0, "kernel size should be odd number"
  34. return kernel_size // 2
  35. class ConvNormAct(nn.Module):
  36. def __init__(
  37. self,
  38. in_channels: int,
  39. out_channels: int,
  40. kernel_size: Union[int, Tuple[int, int]] = 3,
  41. stride: int = 1,
  42. dilation: int = 1,
  43. groups: int = 1,
  44. bias: bool = False,
  45. dropout: float = 0.,
  46. norm_layer: Optional[Type[nn.Module]] = nn.BatchNorm2d,
  47. act_layer: Optional[Type[nn.Module]] = nn.ReLU,
  48. device=None,
  49. dtype=None,
  50. ):
  51. dd = {'device': device, 'dtype': dtype}
  52. super().__init__()
  53. self.dropout = nn.Dropout(dropout, inplace=False)
  54. self.conv = create_conv2d(
  55. in_channels,
  56. out_channels,
  57. kernel_size=kernel_size,
  58. stride=stride,
  59. dilation=dilation,
  60. groups=groups,
  61. bias=bias,
  62. **dd,
  63. )
  64. self.norm = norm_layer(num_features=out_channels, **dd) if norm_layer else nn.Identity()
  65. self.act = act_layer(inplace=True) if act_layer is not None else nn.Identity()
  66. def forward(self, x):
  67. x = self.dropout(x)
  68. x = self.conv(x)
  69. x = self.norm(x)
  70. x = self.act(x)
  71. return x
  72. class DSConv(nn.Module):
  73. def __init__(
  74. self,
  75. in_channels: int,
  76. out_channels: int,
  77. kernel_size: int = 3,
  78. stride: int = 1,
  79. use_bias: Union[bool, Tuple[bool, bool]] = False,
  80. norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
  81. act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None),
  82. device=None,
  83. dtype=None,
  84. ):
  85. dd = {'device': device, 'dtype': dtype}
  86. super().__init__()
  87. use_bias = val2tuple(use_bias, 2)
  88. norm_layer = val2tuple(norm_layer, 2)
  89. act_layer = val2tuple(act_layer, 2)
  90. self.depth_conv = ConvNormAct(
  91. in_channels,
  92. in_channels,
  93. kernel_size,
  94. stride,
  95. groups=in_channels,
  96. norm_layer=norm_layer[0],
  97. act_layer=act_layer[0],
  98. bias=use_bias[0],
  99. **dd,
  100. )
  101. self.point_conv = ConvNormAct(
  102. in_channels,
  103. out_channels,
  104. 1,
  105. norm_layer=norm_layer[1],
  106. act_layer=act_layer[1],
  107. bias=use_bias[1],
  108. **dd,
  109. )
  110. def forward(self, x):
  111. x = self.depth_conv(x)
  112. x = self.point_conv(x)
  113. return x
  114. class ConvBlock(nn.Module):
  115. def __init__(
  116. self,
  117. in_channels: int,
  118. out_channels: int,
  119. kernel_size: int = 3,
  120. stride: int = 1,
  121. mid_channels: Optional[int] = None,
  122. expand_ratio: float = 1,
  123. use_bias: Union[bool, Tuple[bool, bool]] = False,
  124. norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
  125. act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None),
  126. device=None,
  127. dtype=None,
  128. ):
  129. dd = {'device': device, 'dtype': dtype}
  130. super().__init__()
  131. use_bias = val2tuple(use_bias, 2)
  132. norm_layer = val2tuple(norm_layer, 2)
  133. act_layer = val2tuple(act_layer, 2)
  134. mid_channels = mid_channels or round(in_channels * expand_ratio)
  135. self.conv1 = ConvNormAct(
  136. in_channels,
  137. mid_channels,
  138. kernel_size,
  139. stride,
  140. norm_layer=norm_layer[0],
  141. act_layer=act_layer[0],
  142. bias=use_bias[0],
  143. **dd,
  144. )
  145. self.conv2 = ConvNormAct(
  146. mid_channels,
  147. out_channels,
  148. kernel_size,
  149. 1,
  150. norm_layer=norm_layer[1],
  151. act_layer=act_layer[1],
  152. bias=use_bias[1],
  153. **dd,
  154. )
  155. def forward(self, x):
  156. x = self.conv1(x)
  157. x = self.conv2(x)
  158. return x
  159. class MBConv(nn.Module):
  160. def __init__(
  161. self,
  162. in_channels: int,
  163. out_channels: int,
  164. kernel_size: int = 3,
  165. stride: int = 1,
  166. mid_channels: Optional[int] = None,
  167. expand_ratio: float = 6,
  168. use_bias: Union[bool, Tuple[bool, ...]] = False,
  169. norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
  170. act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, nn.ReLU6, None),
  171. device=None,
  172. dtype=None,
  173. ):
  174. dd = {'device': device, 'dtype': dtype}
  175. super().__init__()
  176. use_bias = val2tuple(use_bias, 3)
  177. norm_layer = val2tuple(norm_layer, 3)
  178. act_layer = val2tuple(act_layer, 3)
  179. mid_channels = mid_channels or round(in_channels * expand_ratio)
  180. self.inverted_conv = ConvNormAct(
  181. in_channels,
  182. mid_channels,
  183. 1,
  184. stride=1,
  185. norm_layer=norm_layer[0],
  186. act_layer=act_layer[0],
  187. bias=use_bias[0],
  188. **dd,
  189. )
  190. self.depth_conv = ConvNormAct(
  191. mid_channels,
  192. mid_channels,
  193. kernel_size,
  194. stride=stride,
  195. groups=mid_channels,
  196. norm_layer=norm_layer[1],
  197. act_layer=act_layer[1],
  198. bias=use_bias[1],
  199. **dd,
  200. )
  201. self.point_conv = ConvNormAct(
  202. mid_channels,
  203. out_channels,
  204. 1,
  205. norm_layer=norm_layer[2],
  206. act_layer=act_layer[2],
  207. bias=use_bias[2],
  208. **dd,
  209. )
  210. def forward(self, x):
  211. x = self.inverted_conv(x)
  212. x = self.depth_conv(x)
  213. x = self.point_conv(x)
  214. return x
  215. class FusedMBConv(nn.Module):
  216. def __init__(
  217. self,
  218. in_channels: int,
  219. out_channels: int,
  220. kernel_size: int = 3,
  221. stride: int = 1,
  222. mid_channels: Optional[int] = None,
  223. expand_ratio: float = 6,
  224. groups: int = 1,
  225. use_bias: Union[bool, Tuple[bool, ...]] = False,
  226. norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = nn.BatchNorm2d,
  227. act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (nn.ReLU6, None),
  228. device=None,
  229. dtype=None,
  230. ):
  231. dd = {'device': device, 'dtype': dtype}
  232. super().__init__()
  233. use_bias = val2tuple(use_bias, 2)
  234. norm_layer = val2tuple(norm_layer, 2)
  235. act_layer = val2tuple(act_layer, 2)
  236. mid_channels = mid_channels or round(in_channels * expand_ratio)
  237. self.spatial_conv = ConvNormAct(
  238. in_channels,
  239. mid_channels,
  240. kernel_size,
  241. stride=stride,
  242. groups=groups,
  243. norm_layer=norm_layer[0],
  244. act_layer=act_layer[0],
  245. bias=use_bias[0],
  246. **dd,
  247. )
  248. self.point_conv = ConvNormAct(
  249. mid_channels,
  250. out_channels,
  251. 1,
  252. norm_layer=norm_layer[1],
  253. act_layer=act_layer[1],
  254. bias=use_bias[1],
  255. **dd,
  256. )
  257. def forward(self, x):
  258. x = self.spatial_conv(x)
  259. x = self.point_conv(x)
  260. return x
  261. class LiteMLA(nn.Module):
  262. """Lightweight multi-scale linear attention"""
  263. def __init__(
  264. self,
  265. in_channels: int,
  266. out_channels: int,
  267. heads: Optional[int] = None,
  268. heads_ratio: float = 1.0,
  269. dim: int = 8,
  270. use_bias: Union[bool, Tuple[bool, ...]] = False,
  271. norm_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (None, nn.BatchNorm2d),
  272. act_layer: Union[Type[nn.Module], Tuple[Optional[Type[nn.Module]], ...]] = (None, None),
  273. kernel_func: Type[nn.Module] = nn.ReLU,
  274. scales: Tuple[int, ...] = (5,),
  275. eps: float = 1e-5,
  276. device=None,
  277. dtype=None,
  278. ):
  279. dd = {'device': device, 'dtype': dtype}
  280. super().__init__()
  281. self.eps = eps
  282. heads = heads or int(in_channels // dim * heads_ratio)
  283. total_dim = heads * dim
  284. use_bias = val2tuple(use_bias, 2)
  285. norm_layer = val2tuple(norm_layer, 2)
  286. act_layer = val2tuple(act_layer, 2)
  287. self.dim = dim
  288. self.qkv = ConvNormAct(
  289. in_channels,
  290. 3 * total_dim,
  291. 1,
  292. bias=use_bias[0],
  293. norm_layer=norm_layer[0],
  294. act_layer=act_layer[0],
  295. **dd,
  296. )
  297. self.aggreg = nn.ModuleList([
  298. nn.Sequential(
  299. nn.Conv2d(
  300. 3 * total_dim,
  301. 3 * total_dim,
  302. scale,
  303. padding=get_same_padding(scale),
  304. groups=3 * total_dim,
  305. bias=use_bias[0],
  306. **dd,
  307. ),
  308. nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0], **dd),
  309. )
  310. for scale in scales
  311. ])
  312. self.kernel_func = kernel_func(inplace=False)
  313. self.proj = ConvNormAct(
  314. total_dim * (1 + len(scales)),
  315. out_channels,
  316. 1,
  317. bias=use_bias[1],
  318. norm_layer=norm_layer[1],
  319. act_layer=act_layer[1],
  320. **dd,
  321. )
  322. def _attn(self, q, k, v):
  323. dtype = v.dtype
  324. q, k, v = q.float(), k.float(), v.float()
  325. kv = k.transpose(-1, -2) @ v
  326. out = q @ kv
  327. out = out[..., :-1] / (out[..., -1:] + self.eps)
  328. return out.to(dtype)
  329. def forward(self, x):
  330. B, _, H, W = x.shape
  331. # generate multi-scale q, k, v
  332. qkv = self.qkv(x)
  333. multi_scale_qkv = [qkv]
  334. for op in self.aggreg:
  335. multi_scale_qkv.append(op(qkv))
  336. multi_scale_qkv = torch.cat(multi_scale_qkv, dim=1)
  337. multi_scale_qkv = multi_scale_qkv.reshape(B, -1, 3 * self.dim, H * W).transpose(-1, -2)
  338. q, k, v = multi_scale_qkv.chunk(3, dim=-1)
  339. # lightweight global attention
  340. q = self.kernel_func(q)
  341. k = self.kernel_func(k)
  342. v = F.pad(v, (0, 1), mode="constant", value=1.)
  343. if not torch.jit.is_scripting():
  344. with torch.autocast(device_type=v.device.type, enabled=False):
  345. out = self._attn(q, k, v)
  346. else:
  347. out = self._attn(q, k, v)
  348. # final projection
  349. out = out.transpose(-1, -2).reshape(B, -1, H, W)
  350. out = self.proj(out)
  351. return out
  352. register_notrace_module(LiteMLA)
  353. class EfficientVitBlock(nn.Module):
  354. def __init__(
  355. self,
  356. in_channels: int,
  357. heads_ratio: float = 1.0,
  358. head_dim: int = 32,
  359. expand_ratio: float = 4,
  360. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  361. act_layer: Type[nn.Module] = nn.Hardswish,
  362. device=None,
  363. dtype=None,
  364. ):
  365. dd = {'device': device, 'dtype': dtype}
  366. super().__init__()
  367. self.context_module = ResidualBlock(
  368. LiteMLA(
  369. in_channels=in_channels,
  370. out_channels=in_channels,
  371. heads_ratio=heads_ratio,
  372. dim=head_dim,
  373. norm_layer=(None, norm_layer),
  374. **dd,
  375. ),
  376. nn.Identity(),
  377. )
  378. self.local_module = ResidualBlock(
  379. MBConv(
  380. in_channels=in_channels,
  381. out_channels=in_channels,
  382. expand_ratio=expand_ratio,
  383. use_bias=(True, True, False),
  384. norm_layer=(None, None, norm_layer),
  385. act_layer=(act_layer, act_layer, None),
  386. **dd,
  387. ),
  388. nn.Identity(),
  389. )
  390. def forward(self, x):
  391. x = self.context_module(x)
  392. x = self.local_module(x)
  393. return x
  394. class ResidualBlock(nn.Module):
  395. def __init__(
  396. self,
  397. main: Optional[nn.Module],
  398. shortcut: Optional[nn.Module] = None,
  399. pre_norm: Optional[nn.Module] = None,
  400. ):
  401. super().__init__()
  402. self.pre_norm = pre_norm if pre_norm is not None else nn.Identity()
  403. self.main = main
  404. self.shortcut = shortcut
  405. def forward(self, x):
  406. res = self.main(self.pre_norm(x))
  407. if self.shortcut is not None:
  408. res = res + self.shortcut(x)
  409. return res
  410. def build_local_block(
  411. in_channels: int,
  412. out_channels: int,
  413. stride: int,
  414. expand_ratio: float,
  415. norm_layer: str,
  416. act_layer: str,
  417. fewer_norm: bool = False,
  418. block_type: str = "default",
  419. device=None,
  420. dtype=None,
  421. ):
  422. dd = {'device': device, 'dtype': dtype}
  423. assert block_type in ["default", "large", "fused"]
  424. if expand_ratio == 1:
  425. if block_type == "default":
  426. block = DSConv(
  427. in_channels=in_channels,
  428. out_channels=out_channels,
  429. stride=stride,
  430. use_bias=(True, False) if fewer_norm else False,
  431. norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
  432. act_layer=(act_layer, None),
  433. **dd,
  434. )
  435. else:
  436. block = ConvBlock(
  437. in_channels=in_channels,
  438. out_channels=out_channels,
  439. stride=stride,
  440. use_bias=(True, False) if fewer_norm else False,
  441. norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
  442. act_layer=(act_layer, None),
  443. **dd,
  444. )
  445. else:
  446. if block_type == "default":
  447. block = MBConv(
  448. in_channels=in_channels,
  449. out_channels=out_channels,
  450. stride=stride,
  451. expand_ratio=expand_ratio,
  452. use_bias=(True, True, False) if fewer_norm else False,
  453. norm_layer=(None, None, norm_layer) if fewer_norm else norm_layer,
  454. act_layer=(act_layer, act_layer, None),
  455. **dd,
  456. )
  457. else:
  458. block = FusedMBConv(
  459. in_channels=in_channels,
  460. out_channels=out_channels,
  461. stride=stride,
  462. expand_ratio=expand_ratio,
  463. use_bias=(True, False) if fewer_norm else False,
  464. norm_layer=(None, norm_layer) if fewer_norm else norm_layer,
  465. act_layer=(act_layer, None),
  466. **dd,
  467. )
  468. return block
  469. class Stem(nn.Sequential):
  470. def __init__(
  471. self,
  472. in_chs: int,
  473. out_chs: int,
  474. depth: int,
  475. norm_layer: Type[nn.Module],
  476. act_layer: Type[nn.Module],
  477. block_type: str = 'default',
  478. device=None,
  479. dtype=None,
  480. ):
  481. super().__init__()
  482. dd = {'device': device, 'dtype': dtype}
  483. self.stride = 2
  484. self.add_module(
  485. 'in_conv',
  486. ConvNormAct(
  487. in_chs,
  488. out_chs,
  489. kernel_size=3,
  490. stride=2,
  491. norm_layer=norm_layer,
  492. act_layer=act_layer,
  493. **dd,
  494. )
  495. )
  496. stem_block = 0
  497. for _ in range(depth):
  498. self.add_module(f'res{stem_block}', ResidualBlock(
  499. build_local_block(
  500. in_channels=out_chs,
  501. out_channels=out_chs,
  502. stride=1,
  503. expand_ratio=1,
  504. norm_layer=norm_layer,
  505. act_layer=act_layer,
  506. block_type=block_type,
  507. **dd,
  508. ),
  509. nn.Identity(),
  510. ))
  511. stem_block += 1
  512. class EfficientVitStage(nn.Module):
  513. def __init__(
  514. self,
  515. in_chs: int,
  516. out_chs: int,
  517. depth: int,
  518. norm_layer: Type[nn.Module],
  519. act_layer: Type[nn.Module],
  520. expand_ratio: float,
  521. head_dim: int,
  522. vit_stage: bool = False,
  523. device=None,
  524. dtype=None,
  525. ):
  526. dd = {'device': device, 'dtype': dtype}
  527. super().__init__()
  528. blocks = [ResidualBlock(
  529. build_local_block(
  530. in_channels=in_chs,
  531. out_channels=out_chs,
  532. stride=2,
  533. expand_ratio=expand_ratio,
  534. norm_layer=norm_layer,
  535. act_layer=act_layer,
  536. fewer_norm=vit_stage,
  537. **dd,
  538. ),
  539. None,
  540. )]
  541. in_chs = out_chs
  542. if vit_stage:
  543. # for stage 3, 4
  544. for _ in range(depth):
  545. blocks.append(
  546. EfficientVitBlock(
  547. in_channels=in_chs,
  548. head_dim=head_dim,
  549. expand_ratio=expand_ratio,
  550. norm_layer=norm_layer,
  551. act_layer=act_layer,
  552. **dd,
  553. )
  554. )
  555. else:
  556. # for stage 1, 2
  557. for i in range(1, depth):
  558. blocks.append(ResidualBlock(
  559. build_local_block(
  560. in_channels=in_chs,
  561. out_channels=out_chs,
  562. stride=1,
  563. expand_ratio=expand_ratio,
  564. norm_layer=norm_layer,
  565. act_layer=act_layer,
  566. **dd,
  567. ),
  568. nn.Identity(),
  569. ))
  570. self.blocks = nn.Sequential(*blocks)
  571. def forward(self, x):
  572. return self.blocks(x)
  573. class EfficientVitLargeStage(nn.Module):
  574. def __init__(
  575. self,
  576. in_chs: int,
  577. out_chs: int,
  578. depth: int,
  579. norm_layer: Type[nn.Module],
  580. act_layer: Type[nn.Module],
  581. head_dim: int,
  582. vit_stage: bool = False,
  583. fewer_norm: bool = False,
  584. device=None,
  585. dtype=None,
  586. ):
  587. dd = {'device': device, 'dtype': dtype}
  588. super().__init__()
  589. blocks = [ResidualBlock(
  590. build_local_block(
  591. in_channels=in_chs,
  592. out_channels=out_chs,
  593. stride=2,
  594. expand_ratio=24 if vit_stage else 16,
  595. norm_layer=norm_layer,
  596. act_layer=act_layer,
  597. fewer_norm=vit_stage or fewer_norm,
  598. block_type='default' if fewer_norm else 'fused',
  599. **dd,
  600. ),
  601. None,
  602. )]
  603. in_chs = out_chs
  604. if vit_stage:
  605. # for stage 4
  606. for _ in range(depth):
  607. blocks.append(
  608. EfficientVitBlock(
  609. in_channels=in_chs,
  610. head_dim=head_dim,
  611. expand_ratio=6,
  612. norm_layer=norm_layer,
  613. act_layer=act_layer,
  614. **dd,
  615. )
  616. )
  617. else:
  618. # for stage 1, 2, 3
  619. for i in range(depth):
  620. blocks.append(ResidualBlock(
  621. build_local_block(
  622. in_channels=in_chs,
  623. out_channels=out_chs,
  624. stride=1,
  625. expand_ratio=4,
  626. norm_layer=norm_layer,
  627. act_layer=act_layer,
  628. fewer_norm=fewer_norm,
  629. block_type='default' if fewer_norm else 'fused',
  630. **dd,
  631. ),
  632. nn.Identity(),
  633. ))
  634. self.blocks = nn.Sequential(*blocks)
  635. def forward(self, x):
  636. return self.blocks(x)
  637. class ClassifierHead(nn.Module):
  638. def __init__(
  639. self,
  640. in_channels: int,
  641. widths: List[int],
  642. num_classes: int = 1000,
  643. dropout: float = 0.,
  644. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  645. act_layer: Optional[Type[nn.Module]] = nn.Hardswish,
  646. pool_type: str = 'avg',
  647. norm_eps: float = 1e-5,
  648. device=None,
  649. dtype=None,
  650. ):
  651. dd = {'device': device, 'dtype': dtype}
  652. super().__init__()
  653. self.widths = widths
  654. self.num_features = widths[-1]
  655. assert pool_type, 'Cannot disable pooling'
  656. self.in_conv = ConvNormAct(in_channels, widths[0], 1, norm_layer=norm_layer, act_layer=act_layer, **dd)
  657. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True)
  658. self.classifier = nn.Sequential(
  659. nn.Linear(widths[0], widths[1], bias=False, **dd),
  660. nn.LayerNorm(widths[1], eps=norm_eps, **dd),
  661. act_layer(inplace=True) if act_layer is not None else nn.Identity(),
  662. nn.Dropout(dropout, inplace=False),
  663. nn.Linear(widths[1], num_classes, bias=True, **dd) if num_classes > 0 else nn.Identity(),
  664. )
  665. def reset(self, num_classes: int, pool_type: Optional[str] = None):
  666. if pool_type is not None:
  667. assert pool_type, 'Cannot disable pooling'
  668. self.global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=True,)
  669. if num_classes > 0:
  670. self.classifier[-1] = nn.Linear(self.num_features, num_classes, bias=True)
  671. else:
  672. self.classifier[-1] = nn.Identity()
  673. def forward(self, x, pre_logits: bool = False):
  674. x = self.in_conv(x)
  675. x = self.global_pool(x)
  676. if pre_logits:
  677. # cannot slice or iterate with torchscript so, this
  678. x = self.classifier[0](x)
  679. x = self.classifier[1](x)
  680. x = self.classifier[2](x)
  681. x = self.classifier[3](x)
  682. else:
  683. x = self.classifier(x)
  684. return x
  685. class EfficientVit(nn.Module):
  686. def __init__(
  687. self,
  688. in_chans: int = 3,
  689. widths: Tuple[int, ...] = (),
  690. depths: Tuple[int, ...] = (),
  691. head_dim: int = 32,
  692. expand_ratio: float = 4,
  693. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  694. act_layer: Type[nn.Module] = nn.Hardswish,
  695. global_pool: str = 'avg',
  696. head_widths: Tuple[int, ...] = (),
  697. drop_rate: float = 0.0,
  698. num_classes: int = 1000,
  699. device=None,
  700. dtype=None,
  701. ):
  702. dd = {'device': device, 'dtype': dtype}
  703. super().__init__()
  704. self.grad_checkpointing = False
  705. self.global_pool = global_pool
  706. self.num_classes = num_classes
  707. # input stem
  708. self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, **dd)
  709. stride = self.stem.stride
  710. # stages
  711. self.feature_info = []
  712. self.stages = nn.Sequential()
  713. in_channels = widths[0]
  714. for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
  715. self.stages.append(EfficientVitStage(
  716. in_channels,
  717. w,
  718. depth=d,
  719. norm_layer=norm_layer,
  720. act_layer=act_layer,
  721. expand_ratio=expand_ratio,
  722. head_dim=head_dim,
  723. vit_stage=i >= 2,
  724. **dd,
  725. ))
  726. stride *= 2
  727. in_channels = w
  728. self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]
  729. self.num_features = in_channels
  730. self.head = ClassifierHead(
  731. self.num_features,
  732. widths=head_widths,
  733. num_classes=num_classes,
  734. dropout=drop_rate,
  735. pool_type=self.global_pool,
  736. **dd,
  737. )
  738. self.head_hidden_size = self.head.num_features
  739. @torch.jit.ignore
  740. def group_matcher(self, coarse=False):
  741. matcher = dict(
  742. stem=r'^stem',
  743. blocks=r'^stages\.(\d+)' if coarse else [
  744. (r'^stages\.(\d+).downsample', (0,)),
  745. (r'^stages\.(\d+)\.\w+\.(\d+)', None),
  746. ]
  747. )
  748. return matcher
  749. @torch.jit.ignore
  750. def set_grad_checkpointing(self, enable=True):
  751. self.grad_checkpointing = enable
  752. @torch.jit.ignore
  753. def get_classifier(self) -> nn.Module:
  754. return self.head.classifier[-1]
  755. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  756. self.num_classes = num_classes
  757. self.head.reset(num_classes, global_pool)
  758. def forward_intermediates(
  759. self,
  760. x: torch.Tensor,
  761. indices: Optional[Union[int, List[int]]] = None,
  762. norm: bool = False,
  763. stop_early: bool = False,
  764. output_fmt: str = 'NCHW',
  765. intermediates_only: bool = False,
  766. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  767. """ Forward features that returns intermediates.
  768. Args:
  769. x: Input image tensor
  770. indices: Take last n blocks if int, all if None, select matching indices if sequence
  771. norm: Apply norm layer to compatible intermediates
  772. stop_early: Stop iterating over blocks when last desired intermediate hit
  773. output_fmt: Shape of intermediate feature outputs
  774. intermediates_only: Only return intermediate features
  775. Returns:
  776. """
  777. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  778. intermediates = []
  779. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  780. # forward pass
  781. x = self.stem(x)
  782. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  783. stages = self.stages
  784. else:
  785. stages = self.stages[:max_index + 1]
  786. for feat_idx, stage in enumerate(stages):
  787. if self.grad_checkpointing and not torch.jit.is_scripting():
  788. x = checkpoint_seq(stages, x)
  789. else:
  790. x = stage(x)
  791. if feat_idx in take_indices:
  792. intermediates.append(x)
  793. if intermediates_only:
  794. return intermediates
  795. return x, intermediates
  796. def prune_intermediate_layers(
  797. self,
  798. indices: Union[int, List[int]] = 1,
  799. prune_norm: bool = False,
  800. prune_head: bool = True,
  801. ):
  802. """ Prune layers not required for specified intermediates.
  803. """
  804. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  805. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  806. if prune_head:
  807. self.reset_classifier(0, '')
  808. return take_indices
  809. def forward_features(self, x):
  810. x = self.stem(x)
  811. if self.grad_checkpointing and not torch.jit.is_scripting():
  812. x = checkpoint_seq(self.stages, x)
  813. else:
  814. x = self.stages(x)
  815. return x
  816. def forward_head(self, x, pre_logits: bool = False):
  817. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  818. def forward(self, x):
  819. x = self.forward_features(x)
  820. x = self.forward_head(x)
  821. return x
  822. class EfficientVitLarge(nn.Module):
  823. def __init__(
  824. self,
  825. in_chans: int = 3,
  826. widths: Tuple[int, ...] = (),
  827. depths: Tuple[int, ...] = (),
  828. head_dim: int = 32,
  829. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  830. act_layer: Type[nn.Module] = GELUTanh,
  831. global_pool: str = 'avg',
  832. head_widths: Tuple[int, ...] = (),
  833. drop_rate: float = 0.0,
  834. num_classes: int = 1000,
  835. norm_eps: float = 1e-7,
  836. device=None,
  837. dtype=None,
  838. ):
  839. dd = {'device': device, 'dtype': dtype}
  840. super().__init__()
  841. self.grad_checkpointing = False
  842. self.global_pool = global_pool
  843. self.num_classes = num_classes
  844. self.norm_eps = norm_eps
  845. norm_layer = partial(norm_layer, eps=self.norm_eps)
  846. # input stem
  847. self.stem = Stem(in_chans, widths[0], depths[0], norm_layer, act_layer, block_type='large', **dd)
  848. stride = self.stem.stride
  849. # stages
  850. self.feature_info = []
  851. self.stages = nn.Sequential()
  852. in_channels = widths[0]
  853. for i, (w, d) in enumerate(zip(widths[1:], depths[1:])):
  854. self.stages.append(EfficientVitLargeStage(
  855. in_channels,
  856. w,
  857. depth=d,
  858. norm_layer=norm_layer,
  859. act_layer=act_layer,
  860. head_dim=head_dim,
  861. vit_stage=i >= 3,
  862. fewer_norm=i >= 2,
  863. **dd,
  864. ))
  865. stride *= 2
  866. in_channels = w
  867. self.feature_info += [dict(num_chs=in_channels, reduction=stride, module=f'stages.{i}')]
  868. self.num_features = in_channels
  869. self.head = ClassifierHead(
  870. self.num_features,
  871. widths=head_widths,
  872. num_classes=num_classes,
  873. dropout=drop_rate,
  874. pool_type=self.global_pool,
  875. act_layer=act_layer,
  876. norm_eps=self.norm_eps,
  877. **dd,
  878. )
  879. self.head_hidden_size = self.head.num_features
  880. @torch.jit.ignore
  881. def group_matcher(self, coarse=False):
  882. matcher = dict(
  883. stem=r'^stem',
  884. blocks=r'^stages\.(\d+)' if coarse else [
  885. (r'^stages\.(\d+).downsample', (0,)),
  886. (r'^stages\.(\d+)\.\w+\.(\d+)', None),
  887. ]
  888. )
  889. return matcher
  890. @torch.jit.ignore
  891. def set_grad_checkpointing(self, enable=True):
  892. self.grad_checkpointing = enable
  893. @torch.jit.ignore
  894. def get_classifier(self) -> nn.Module:
  895. return self.head.classifier[-1]
  896. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  897. self.num_classes = num_classes
  898. self.head.reset(num_classes, global_pool)
  899. def forward_intermediates(
  900. self,
  901. x: torch.Tensor,
  902. indices: Optional[Union[int, List[int]]] = None,
  903. norm: bool = False,
  904. stop_early: bool = False,
  905. output_fmt: str = 'NCHW',
  906. intermediates_only: bool = False,
  907. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  908. """ Forward features that returns intermediates.
  909. Args:
  910. x: Input image tensor
  911. indices: Take last n blocks if int, all if None, select matching indices if sequence
  912. norm: Apply norm layer to compatible intermediates
  913. stop_early: Stop iterating over blocks when last desired intermediate hit
  914. output_fmt: Shape of intermediate feature outputs
  915. intermediates_only: Only return intermediate features
  916. Returns:
  917. """
  918. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  919. intermediates = []
  920. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  921. # forward pass
  922. x = self.stem(x)
  923. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  924. stages = self.stages
  925. else:
  926. stages = self.stages[:max_index + 1]
  927. for feat_idx, stage in enumerate(stages):
  928. if self.grad_checkpointing and not torch.jit.is_scripting():
  929. x = checkpoint_seq(stages, x)
  930. else:
  931. x = stage(x)
  932. if feat_idx in take_indices:
  933. intermediates.append(x)
  934. if intermediates_only:
  935. return intermediates
  936. return x, intermediates
  937. def prune_intermediate_layers(
  938. self,
  939. indices: Union[int, List[int]] = 1,
  940. prune_norm: bool = False,
  941. prune_head: bool = True,
  942. ):
  943. """ Prune layers not required for specified intermediates.
  944. """
  945. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  946. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  947. if prune_head:
  948. self.reset_classifier(0, '')
  949. return take_indices
  950. def forward_features(self, x):
  951. x = self.stem(x)
  952. if self.grad_checkpointing and not torch.jit.is_scripting():
  953. x = checkpoint_seq(self.stages, x)
  954. else:
  955. x = self.stages(x)
  956. return x
  957. def forward_head(self, x, pre_logits: bool = False):
  958. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  959. def forward(self, x):
  960. x = self.forward_features(x)
  961. x = self.forward_head(x)
  962. return x
  963. def _cfg(url='', **kwargs):
  964. return {
  965. 'url': url,
  966. 'num_classes': 1000,
  967. 'mean': IMAGENET_DEFAULT_MEAN,
  968. 'std': IMAGENET_DEFAULT_STD,
  969. 'first_conv': 'stem.in_conv.conv',
  970. 'classifier': 'head.classifier.4',
  971. 'crop_pct': 0.95,
  972. 'license': 'apache-2.0',
  973. 'input_size': (3, 224, 224),
  974. 'pool_size': (7, 7),
  975. **kwargs,
  976. }
  977. default_cfgs = generate_default_cfgs({
  978. 'efficientvit_b0.r224_in1k': _cfg(
  979. hf_hub_id='timm/',
  980. ),
  981. 'efficientvit_b1.r224_in1k': _cfg(
  982. hf_hub_id='timm/',
  983. ),
  984. 'efficientvit_b1.r256_in1k': _cfg(
  985. hf_hub_id='timm/',
  986. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  987. ),
  988. 'efficientvit_b1.r288_in1k': _cfg(
  989. hf_hub_id='timm/',
  990. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
  991. ),
  992. 'efficientvit_b2.r224_in1k': _cfg(
  993. hf_hub_id='timm/',
  994. ),
  995. 'efficientvit_b2.r256_in1k': _cfg(
  996. hf_hub_id='timm/',
  997. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  998. ),
  999. 'efficientvit_b2.r288_in1k': _cfg(
  1000. hf_hub_id='timm/',
  1001. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
  1002. ),
  1003. 'efficientvit_b3.r224_in1k': _cfg(
  1004. hf_hub_id='timm/',
  1005. ),
  1006. 'efficientvit_b3.r256_in1k': _cfg(
  1007. hf_hub_id='timm/',
  1008. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  1009. ),
  1010. 'efficientvit_b3.r288_in1k': _cfg(
  1011. hf_hub_id='timm/',
  1012. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
  1013. ),
  1014. 'efficientvit_l1.r224_in1k': _cfg(
  1015. hf_hub_id='timm/',
  1016. crop_pct=1.0,
  1017. ),
  1018. 'efficientvit_l2.r224_in1k': _cfg(
  1019. hf_hub_id='timm/',
  1020. crop_pct=1.0,
  1021. ),
  1022. 'efficientvit_l2.r256_in1k': _cfg(
  1023. hf_hub_id='timm/',
  1024. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  1025. ),
  1026. 'efficientvit_l2.r288_in1k': _cfg(
  1027. hf_hub_id='timm/',
  1028. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0,
  1029. ),
  1030. 'efficientvit_l2.r384_in1k': _cfg(
  1031. hf_hub_id='timm/',
  1032. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  1033. ),
  1034. 'efficientvit_l3.r224_in1k': _cfg(
  1035. hf_hub_id='timm/',
  1036. crop_pct=1.0,
  1037. ),
  1038. 'efficientvit_l3.r256_in1k': _cfg(
  1039. hf_hub_id='timm/',
  1040. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0,
  1041. ),
  1042. 'efficientvit_l3.r320_in1k': _cfg(
  1043. hf_hub_id='timm/',
  1044. input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
  1045. ),
  1046. 'efficientvit_l3.r384_in1k': _cfg(
  1047. hf_hub_id='timm/',
  1048. input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0,
  1049. ),
  1050. # 'efficientvit_l0_sam.sam': _cfg(
  1051. # # hf_hub_id='timm/',
  1052. # input_size=(3, 512, 512), crop_pct=1.0,
  1053. # num_classes=0,
  1054. # ),
  1055. # 'efficientvit_l1_sam.sam': _cfg(
  1056. # # hf_hub_id='timm/',
  1057. # input_size=(3, 512, 512), crop_pct=1.0,
  1058. # num_classes=0,
  1059. # ),
  1060. # 'efficientvit_l2_sam.sam': _cfg(
  1061. # # hf_hub_id='timm/',f
  1062. # input_size=(3, 512, 512), crop_pct=1.0,
  1063. # num_classes=0,
  1064. # ),
  1065. })
  1066. def _create_efficientvit(variant, pretrained=False, **kwargs):
  1067. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  1068. model = build_model_with_cfg(
  1069. EfficientVit,
  1070. variant,
  1071. pretrained,
  1072. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  1073. **kwargs
  1074. )
  1075. return model
  1076. def _create_efficientvit_large(variant, pretrained=False, **kwargs):
  1077. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
  1078. model = build_model_with_cfg(
  1079. EfficientVitLarge,
  1080. variant,
  1081. pretrained,
  1082. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  1083. **kwargs
  1084. )
  1085. return model
  1086. @register_model
  1087. def efficientvit_b0(pretrained=False, **kwargs):
  1088. model_args = dict(
  1089. widths=(8, 16, 32, 64, 128), depths=(1, 2, 2, 2, 2), head_dim=16, head_widths=(1024, 1280))
  1090. return _create_efficientvit('efficientvit_b0', pretrained=pretrained, **dict(model_args, **kwargs))
  1091. @register_model
  1092. def efficientvit_b1(pretrained=False, **kwargs):
  1093. model_args = dict(
  1094. widths=(16, 32, 64, 128, 256), depths=(1, 2, 3, 3, 4), head_dim=16, head_widths=(1536, 1600))
  1095. return _create_efficientvit('efficientvit_b1', pretrained=pretrained, **dict(model_args, **kwargs))
  1096. @register_model
  1097. def efficientvit_b2(pretrained=False, **kwargs):
  1098. model_args = dict(
  1099. widths=(24, 48, 96, 192, 384), depths=(1, 3, 4, 4, 6), head_dim=32, head_widths=(2304, 2560))
  1100. return _create_efficientvit('efficientvit_b2', pretrained=pretrained, **dict(model_args, **kwargs))
  1101. @register_model
  1102. def efficientvit_b3(pretrained=False, **kwargs):
  1103. model_args = dict(
  1104. widths=(32, 64, 128, 256, 512), depths=(1, 4, 6, 6, 9), head_dim=32, head_widths=(2304, 2560))
  1105. return _create_efficientvit('efficientvit_b3', pretrained=pretrained, **dict(model_args, **kwargs))
  1106. @register_model
  1107. def efficientvit_l1(pretrained=False, **kwargs):
  1108. model_args = dict(
  1109. widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, head_widths=(3072, 3200))
  1110. return _create_efficientvit_large('efficientvit_l1', pretrained=pretrained, **dict(model_args, **kwargs))
  1111. @register_model
  1112. def efficientvit_l2(pretrained=False, **kwargs):
  1113. model_args = dict(
  1114. widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, head_widths=(3072, 3200))
  1115. return _create_efficientvit_large('efficientvit_l2', pretrained=pretrained, **dict(model_args, **kwargs))
  1116. @register_model
  1117. def efficientvit_l3(pretrained=False, **kwargs):
  1118. model_args = dict(
  1119. widths=(64, 128, 256, 512, 1024), depths=(1, 2, 2, 8, 8), head_dim=32, head_widths=(6144, 6400))
  1120. return _create_efficientvit_large('efficientvit_l3', pretrained=pretrained, **dict(model_args, **kwargs))
  1121. # FIXME will wait for v2 SAM models which are pending
  1122. # @register_model
  1123. # def efficientvit_l0_sam(pretrained=False, **kwargs):
  1124. # # only backbone for segment-anything-model weights
  1125. # model_args = dict(
  1126. # widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 4, 4), head_dim=32, num_classes=0, norm_eps=1e-6)
  1127. # return _create_efficientvit_large('efficientvit_l0_sam', pretrained=pretrained, **dict(model_args, **kwargs))
  1128. #
  1129. #
  1130. # @register_model
  1131. # def efficientvit_l1_sam(pretrained=False, **kwargs):
  1132. # # only backbone for segment-anything-model weights
  1133. # model_args = dict(
  1134. # widths=(32, 64, 128, 256, 512), depths=(1, 1, 1, 6, 6), head_dim=32, num_classes=0, norm_eps=1e-6)
  1135. # return _create_efficientvit_large('efficientvit_l1_sam', pretrained=pretrained, **dict(model_args, **kwargs))
  1136. #
  1137. #
  1138. # @register_model
  1139. # def efficientvit_l2_sam(pretrained=False, **kwargs):
  1140. # # only backbone for segment-anything-model weights
  1141. # model_args = dict(
  1142. # widths=(32, 64, 128, 256, 512), depths=(1, 2, 2, 8, 8), head_dim=32, num_classes=0, norm_eps=1e-6)
  1143. # return _create_efficientvit_large('efficientvit_l2_sam', pretrained=pretrained, **dict(model_args, **kwargs))