convnext.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409
  1. """ ConvNeXt
  2. Papers:
  3. * `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
  4. @Article{liu2022convnet,
  5. author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
  6. title = {A ConvNet for the 2020s},
  7. journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  8. year = {2022},
  9. }
  10. * `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
  11. @article{Woo2023ConvNeXtV2,
  12. title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
  13. author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
  14. year={2023},
  15. journal={arXiv preprint arXiv:2301.00808},
  16. }
  17. Original code and weights from:
  18. * https://github.com/facebookresearch/ConvNeXt, original copyright below
  19. * https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
  20. Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
  21. Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
  22. """
  23. # ConvNeXt
  24. # Copyright (c) Meta Platforms, Inc. and affiliates.
  25. # All rights reserved.
  26. # This source code is licensed under the MIT license
  27. # ConvNeXt-V2
  28. # Copyright (c) Meta Platforms, Inc. and affiliates.
  29. # All rights reserved.
  30. # This source code is licensed under the license found in the
  31. # LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
  32. # No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
  33. from functools import partial
  34. from typing import Callable, Dict, List, Optional, Tuple, Union
  35. import torch
  36. import torch.nn as nn
  37. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
  38. from timm.layers import (
  39. trunc_normal_,
  40. AvgPool2dSame,
  41. DropPath,
  42. calculate_drop_path_rates,
  43. Mlp,
  44. GlobalResponseNormMlp,
  45. LayerNorm2d,
  46. LayerNorm,
  47. RmsNorm2d,
  48. RmsNorm,
  49. SimpleNorm2d,
  50. SimpleNorm,
  51. create_conv2d,
  52. get_act_layer,
  53. get_norm_layer,
  54. make_divisible,
  55. to_ntuple,
  56. NormMlpClassifierHead,
  57. ClassifierHead,
  58. )
  59. from ._builder import build_model_with_cfg
  60. from ._features import feature_take_indices
  61. from ._manipulate import named_apply, checkpoint_seq
  62. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
__all__ = ['ConvNeXt']  # public API; model_registry will add each entrypoint fn to this
  64. class Downsample(nn.Module):
  65. """Downsample module for ConvNeXt."""
  66. def __init__(
  67. self,
  68. in_chs: int,
  69. out_chs: int,
  70. stride: int = 1,
  71. dilation: int = 1,
  72. device=None,
  73. dtype=None,
  74. ) -> None:
  75. """Initialize Downsample module.
  76. Args:
  77. in_chs: Number of input channels.
  78. out_chs: Number of output channels.
  79. stride: Stride for downsampling.
  80. dilation: Dilation rate.
  81. """
  82. dd = {'device': device, 'dtype': dtype}
  83. super().__init__()
  84. avg_stride = stride if dilation == 1 else 1
  85. if stride > 1 or dilation > 1:
  86. avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
  87. self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
  88. else:
  89. self.pool = nn.Identity()
  90. if in_chs != out_chs:
  91. self.conv = create_conv2d(in_chs, out_chs, 1, stride=1, **dd)
  92. else:
  93. self.conv = nn.Identity()
  94. def forward(self, x: torch.Tensor) -> torch.Tensor:
  95. """Forward pass."""
  96. x = self.pool(x)
  97. x = self.conv(x)
  98. return x
  99. class ConvNeXtBlock(nn.Module):
  100. """ConvNeXt Block.
  101. There are two equivalent implementations:
  102. (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
  103. (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
  104. Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
  105. choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
  106. is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
  107. """
  108. def __init__(
  109. self,
  110. in_chs: int,
  111. out_chs: Optional[int] = None,
  112. kernel_size: int = 7,
  113. stride: int = 1,
  114. dilation: Union[int, Tuple[int, int]] = (1, 1),
  115. mlp_ratio: float = 4,
  116. conv_mlp: bool = False,
  117. conv_bias: bool = True,
  118. use_grn: bool = False,
  119. ls_init_value: Optional[float] = 1e-6,
  120. act_layer: Union[str, Callable] = 'gelu',
  121. norm_layer: Optional[Callable] = None,
  122. drop_path: float = 0.,
  123. device=None,
  124. dtype=None,
  125. ):
  126. """
  127. Args:
  128. in_chs: Block input channels.
  129. out_chs: Block output channels (same as in_chs if None).
  130. kernel_size: Depthwise convolution kernel size.
  131. stride: Stride of depthwise convolution.
  132. dilation: Tuple specifying input and output dilation of block.
  133. mlp_ratio: MLP expansion ratio.
  134. conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
  135. conv_bias: Apply bias for all convolution (linear) layers.
  136. use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
  137. ls_init_value: Layer-scale init values, layer-scale applied if not None.
  138. act_layer: Activation layer.
  139. norm_layer: Normalization layer (defaults to LN if not specified).
  140. drop_path: Stochastic depth probability.
  141. """
  142. dd = {'device': device, 'dtype': dtype}
  143. super().__init__()
  144. out_chs = out_chs or in_chs
  145. dilation = to_ntuple(2)(dilation)
  146. act_layer = get_act_layer(act_layer)
  147. if not norm_layer:
  148. norm_layer = LayerNorm2d if conv_mlp else LayerNorm
  149. mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
  150. self.use_conv_mlp = conv_mlp
  151. self.conv_dw = create_conv2d(
  152. in_chs,
  153. out_chs,
  154. kernel_size=kernel_size,
  155. stride=stride,
  156. dilation=dilation[0],
  157. depthwise=True,
  158. bias=conv_bias,
  159. **dd,
  160. )
  161. self.norm = norm_layer(out_chs, **dd)
  162. self.mlp = mlp_layer(
  163. out_chs,
  164. int(mlp_ratio * out_chs),
  165. act_layer=act_layer,
  166. **dd,
  167. )
  168. self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs, **dd)) if ls_init_value is not None else None
  169. if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
  170. self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0], **dd)
  171. else:
  172. self.shortcut = nn.Identity()
  173. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  174. def forward(self, x: torch.Tensor) -> torch.Tensor:
  175. """Forward pass."""
  176. shortcut = x
  177. x = self.conv_dw(x)
  178. if self.use_conv_mlp:
  179. x = self.norm(x)
  180. x = self.mlp(x)
  181. else:
  182. x = x.permute(0, 2, 3, 1)
  183. x = self.norm(x)
  184. x = self.mlp(x)
  185. x = x.permute(0, 3, 1, 2)
  186. if self.gamma is not None:
  187. x = x.mul(self.gamma.reshape(1, -1, 1, 1))
  188. x = self.drop_path(x) + self.shortcut(shortcut)
  189. return x
  190. class ConvNeXtStage(nn.Module):
  191. """ConvNeXt stage (multiple blocks)."""
  192. def __init__(
  193. self,
  194. in_chs: int,
  195. out_chs: int,
  196. kernel_size: int = 7,
  197. stride: int = 2,
  198. depth: int = 2,
  199. dilation: Tuple[int, int] = (1, 1),
  200. drop_path_rates: Optional[List[float]] = None,
  201. ls_init_value: float = 1.0,
  202. conv_mlp: bool = False,
  203. conv_bias: bool = True,
  204. use_grn: bool = False,
  205. act_layer: Union[str, Callable] = 'gelu',
  206. norm_layer: Optional[Callable] = None,
  207. norm_layer_cl: Optional[Callable] = None,
  208. device=None,
  209. dtype=None,
  210. ) -> None:
  211. """Initialize ConvNeXt stage.
  212. Args:
  213. in_chs: Number of input channels.
  214. out_chs: Number of output channels.
  215. kernel_size: Kernel size for depthwise convolution.
  216. stride: Stride for downsampling.
  217. depth: Number of blocks in stage.
  218. dilation: Dilation rates.
  219. drop_path_rates: Drop path rates for each block.
  220. ls_init_value: Initial value for layer scale.
  221. conv_mlp: Use convolutional MLP.
  222. conv_bias: Use bias in convolutions.
  223. use_grn: Use global response normalization.
  224. act_layer: Activation layer.
  225. norm_layer: Normalization layer.
  226. norm_layer_cl: Normalization layer for channels last.
  227. """
  228. dd = {'device': device, 'dtype': dtype}
  229. super().__init__()
  230. self.grad_checkpointing = False
  231. if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
  232. ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
  233. pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
  234. self.downsample = nn.Sequential(
  235. norm_layer(in_chs, **dd),
  236. create_conv2d(
  237. in_chs,
  238. out_chs,
  239. kernel_size=ds_ks,
  240. stride=stride,
  241. dilation=dilation[0],
  242. padding=pad,
  243. bias=conv_bias,
  244. **dd,
  245. ),
  246. )
  247. in_chs = out_chs
  248. else:
  249. self.downsample = nn.Identity()
  250. drop_path_rates = drop_path_rates or [0.] * depth
  251. stage_blocks = []
  252. for i in range(depth):
  253. stage_blocks.append(ConvNeXtBlock(
  254. in_chs=in_chs,
  255. out_chs=out_chs,
  256. kernel_size=kernel_size,
  257. dilation=dilation[1],
  258. drop_path=drop_path_rates[i],
  259. ls_init_value=ls_init_value,
  260. conv_mlp=conv_mlp,
  261. conv_bias=conv_bias,
  262. use_grn=use_grn,
  263. act_layer=act_layer,
  264. norm_layer=norm_layer if conv_mlp else norm_layer_cl,
  265. **dd,
  266. ))
  267. in_chs = out_chs
  268. self.blocks = nn.Sequential(*stage_blocks)
  269. def forward(self, x: torch.Tensor) -> torch.Tensor:
  270. """Forward pass."""
  271. x = self.downsample(x)
  272. if self.grad_checkpointing and not torch.jit.is_scripting():
  273. x = checkpoint_seq(self.blocks, x)
  274. else:
  275. x = self.blocks(x)
  276. return x
  277. # map of norm layers with NCHW (2D) and channels last variants
  278. _NORM_MAP = {
  279. 'layernorm': (LayerNorm2d, LayerNorm),
  280. 'layernorm2d': (LayerNorm2d, LayerNorm),
  281. 'simplenorm': (SimpleNorm2d, SimpleNorm),
  282. 'simplenorm2d': (SimpleNorm2d, SimpleNorm),
  283. 'rmsnorm': (RmsNorm2d, RmsNorm),
  284. 'rmsnorm2d': (RmsNorm2d, RmsNorm),
  285. }
  286. def _get_norm_layers(norm_layer: Union[Callable, str], conv_mlp: bool, norm_eps: float):
  287. norm_layer = norm_layer or 'layernorm'
  288. if norm_layer in _NORM_MAP:
  289. norm_layer_cl = _NORM_MAP[norm_layer][0] if conv_mlp else _NORM_MAP[norm_layer][1]
  290. norm_layer = _NORM_MAP[norm_layer][0]
  291. if norm_eps is not None:
  292. norm_layer = partial(norm_layer, eps=norm_eps)
  293. norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
  294. else:
  295. assert conv_mlp, \
  296. 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
  297. norm_layer = get_norm_layer(norm_layer)
  298. norm_layer_cl = norm_layer
  299. if norm_eps is not None:
  300. norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
  301. return norm_layer, norm_layer_cl
class ConvNeXt(nn.Module):
    """ConvNeXt model architecture.

    A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf

    The model is built as `stem` -> four `ConvNeXtStage` modules (`stages`) -> `norm_pre` -> `head`.
    """

    def __init__(
            self,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            output_stride: int = 32,
            depths: Tuple[int, ...] = (3, 3, 9, 3),
            dims: Tuple[int, ...] = (96, 192, 384, 768),
            kernel_sizes: Union[int, Tuple[int, ...]] = 7,
            ls_init_value: Optional[float] = 1e-6,
            stem_type: str = 'patch',
            patch_size: int = 4,
            head_init_scale: float = 1.,
            head_norm_first: bool = False,
            head_hidden_size: Optional[int] = None,
            conv_mlp: bool = False,
            conv_bias: bool = True,
            use_grn: bool = False,
            act_layer: Union[str, Callable] = 'gelu',
            norm_layer: Optional[Union[str, Callable]] = None,
            norm_eps: Optional[float] = None,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            device=None,
            dtype=None,
    ):
        """
        Args:
            in_chans: Number of input image channels.
            num_classes: Number of classes for classification head.
            global_pool: Global pooling type.
            output_stride: Output stride of network, one of (8, 16, 32).
            depths: Number of blocks at each stage.
            dims: Feature dimension at each stage.
            kernel_sizes: Depthwise convolution kernel-sizes for each stage.
            ls_init_value: Init value for Layer Scale, disabled if None.
            stem_type: Type of stem, one of 'patch', 'overlap', 'overlap_tiered', 'overlap_act'.
            patch_size: Stem patch size for patch stem.
            head_init_scale: Init scaling value for classifier weights and biases.
            head_norm_first: Apply normalization before global pool + head.
            head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
            conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
            conv_bias: Use bias layers w/ all convolutions.
            use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
            act_layer: Activation layer type.
            norm_layer: Normalization layer type.
            norm_eps: Optional norm epsilon, forwarded to `_get_norm_layers`.
            drop_rate: Head pre-classifier dropout rate.
            drop_path_rate: Stochastic depth drop rate.
            device: Device forwarded to created layers.
            dtype: Dtype forwarded to created layers.
        """
        super().__init__()
        dd = {'device': device, 'dtype': dtype}
        assert output_stride in (8, 16, 32)
        kernel_sizes = to_ntuple(4)(kernel_sizes)
        # norm_layer is the NCHW factory (stem / downsample / head); norm_layer_cl goes to the stages
        norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps)
        act_layer = get_act_layer(act_layer)

        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.feature_info = []

        assert stem_type in ('patch', 'overlap', 'overlap_tiered', 'overlap_act')
        if stem_type == 'patch':
            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd),
                norm_layer(dims[0], **dd),
            )
            stem_stride = patch_size
        else:
            # overlapping stems: two 3x3 stride-2 convs; 'tiered' halves the width of the first conv,
            # 'act' inserts an activation between them (filter(None, ...) drops the placeholder otherwise)
            mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
            self.stem = nn.Sequential(*filter(None, [
                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
                act_layer() if 'act' in stem_type else None,
                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
                norm_layer(dims[0], **dd),
            ]))
            stem_stride = 4

        # placeholder; replaced by the populated nn.Sequential once the stages are built below
        self.stages = nn.Sequential()
        dp_rates = calculate_drop_path_rates(drop_path_rate, depths, stagewise=True)
        stages = []
        prev_chs = dims[0]
        curr_stride = stem_stride
        dilation = 1
        # 4 feature resolution stages, each consisting of multiple residual blocks
        for i in range(4):
            stride = 2 if curr_stride == 2 or i > 0 else 1
            if curr_stride >= output_stride and stride > 1:
                # target output stride reached: trade further striding for dilation
                dilation *= stride
                stride = 1
            curr_stride *= stride
            first_dilation = 1 if dilation in (1, 2) else 2
            out_chs = dims[i]
            stages.append(ConvNeXtStage(
                prev_chs,
                out_chs,
                kernel_size=kernel_sizes[i],
                stride=stride,
                dilation=(first_dilation, dilation),
                depth=depths[i],
                drop_path_rates=dp_rates[i],
                ls_init_value=ls_init_value,
                conv_mlp=conv_mlp,
                conv_bias=conv_bias,
                use_grn=use_grn,
                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
                **dd,
            ))
            prev_chs = out_chs
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
            self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)
        self.num_features = self.head_hidden_size = prev_chs

        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
        if head_norm_first:
            assert not head_hidden_size
            self.norm_pre = norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                **dd,
            )
        else:
            self.norm_pre = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                hidden_size=head_hidden_size,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
                norm_layer=norm_layer,
                act_layer='gelu',
                **dd,
            )
            self.head_hidden_size = self.head.num_features
        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]:
        """Create regex patterns for parameter grouping.

        Args:
            coarse: Use coarse grouping (one pattern per stage) instead of per-block patterns.

        Returns:
            Dictionary mapping group names to regex patterns.
        """
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.downsample', (0,)),  # stage downsample layers
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^norm_pre', (99999,))
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True) -> None:
        """Enable or disable gradient checkpointing.

        Args:
            enable: Whether to enable gradient checkpointing.
        """
        # each stage reads its own `grad_checkpointing` flag in forward()
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        """Get the classifier module."""
        return self.head.fc

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
        """Reset the classifier head.

        Args:
            num_classes: Number of classes for new classifier.
            global_pool: Global pooling type.
        """
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """Forward features that returns intermediates.

        Args:
            x: Input image tensor.
            indices: Take last n blocks if int, all if None, select matching indices if sequence.
            norm: Apply norm layer to compatible intermediates.
            stop_early: Stop iterating over blocks when last desired intermediate hit.
            output_fmt: Shape of intermediate feature outputs.
            intermediates_only: Only return intermediate features.

        Returns:
            List of intermediate features or tuple of (final features, intermediates).
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)

        # forward pass
        x = self.stem(x)
        last_idx = len(self.stages) - 1
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]
        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                # only the last stage's output is passed through norm_pre when norm=True
                if norm and feat_idx == last_idx:
                    intermediates.append(self.norm_pre(x))
                else:
                    intermediates.append(x)

        if intermediates_only:
            return intermediates

        # norm_pre is applied to the final features only if the last stage was actually run
        if feat_idx == last_idx:
            x = self.norm_pre(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ) -> List[int]:
        """Prune layers not required for specified intermediates.

        Args:
            indices: Indices of intermediate layers to keep.
            prune_norm: Whether to prune normalization layer.
            prune_head: Whether to prune the classifier head.

        Returns:
            List of indices that were kept.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # keep stages up to and including max_index
        if prune_norm:
            self.norm_pre = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through feature extraction layers."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm_pre(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        """Forward pass through classifier head.

        Args:
            x: Feature tensor.
            pre_logits: Return features before final classifier.

        Returns:
            Output tensor.
        """
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass."""
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
  563. def _init_weights(module: nn.Module, name: Optional[str] = None, head_init_scale: float = 1.0) -> None:
  564. """Initialize model weights.
  565. Args:
  566. module: Module to initialize.
  567. name: Module name.
  568. head_init_scale: Scale factor for head initialization.
  569. """
  570. if isinstance(module, nn.Conv2d):
  571. trunc_normal_(module.weight, std=.02)
  572. if module.bias is not None:
  573. nn.init.zeros_(module.bias)
  574. elif isinstance(module, nn.Linear):
  575. trunc_normal_(module.weight, std=.02)
  576. nn.init.zeros_(module.bias)
  577. if name and 'head.' in name:
  578. module.weight.data.mul_(head_init_scale)
  579. module.bias.data.mul_(head_init_scale)
  580. def checkpoint_filter_fn(state_dict, model):
  581. """ Remap FB checkpoints -> timm """
  582. if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
  583. return state_dict # non-FB checkpoint
  584. if 'model' in state_dict:
  585. state_dict = state_dict['model']
  586. out_dict = {}
  587. if 'visual.trunk.stem.0.weight' in state_dict:
  588. out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
  589. if 'visual.head.proj.weight' in state_dict:
  590. out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
  591. out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
  592. elif 'visual.head.mlp.fc1.weight' in state_dict:
  593. out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
  594. out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
  595. out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
  596. out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
  597. return out_dict
  598. import re
  599. for k, v in state_dict.items():
  600. k = k.replace('downsample_layers.0.', 'stem.')
  601. k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
  602. k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
  603. k = k.replace('dwconv', 'conv_dw')
  604. k = k.replace('pwconv', 'mlp.fc')
  605. if 'grn' in k:
  606. k = k.replace('grn.beta', 'mlp.grn.bias')
  607. k = k.replace('grn.gamma', 'mlp.grn.weight')
  608. v = v.reshape(v.shape[-1])
  609. k = k.replace('head.', 'head.fc.')
  610. if k.startswith('norm.'):
  611. k = k.replace('norm', 'head.norm')
  612. if v.ndim == 2 and 'head' not in k:
  613. model_shape = model.state_dict()[k].shape
  614. v = v.reshape(model_shape)
  615. out_dict[k] = v
  616. return out_dict
  617. def _create_convnext(variant, pretrained=False, **kwargs):
  618. if kwargs.get('pretrained_cfg', '') == 'fcmae':
  619. # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
  620. # This is workaround loading with num_classes=0 w/o removing norm-layer.
  621. kwargs.setdefault('pretrained_strict', False)
  622. model = build_model_with_cfg(
  623. ConvNeXt, variant, pretrained,
  624. pretrained_filter_fn=checkpoint_filter_fn,
  625. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  626. **kwargs)
  627. return model
  628. def _cfg(url='', **kwargs):
  629. return {
  630. 'url': url,
  631. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  632. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  633. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  634. 'first_conv': 'stem.0', 'classifier': 'head.fc',
  635. 'license': 'apache-2.0', **kwargs
  636. }
  637. def _cfgv2(url='', **kwargs):
  638. return {
  639. 'url': url,
  640. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  641. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  642. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  643. 'first_conv': 'stem.0', 'classifier': 'head.fc',
  644. 'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
  645. 'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
  646. 'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
  647. **kwargs
  648. }
# Pretrained-weight configurations, keyed as '<model_name>.<tag>'. The part before the
# '.' matches one of the @register_model entrypoints defined below. Each value is built
# by `_cfg` / `_cfgv2`, with keyword arguments overriding that helper's defaults, and
# the whole mapping is passed through `generate_default_cfgs`.
default_cfgs = generate_default_cfgs({
    # timm specific variants
    'convnext_tiny.in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_small.in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'convnext_zepto_rms.ra4_e3600_r224_in1k': _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    'convnext_zepto_rms_ols.ra4_e3600_r224_in1k': _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        crop_pct=0.9),
    'convnext_atto.d2_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnext_atto_ols.a2_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    # No url and no hub id configured for this entry (hub id is commented out).
    'convnext_atto_rms.untrained': _cfg(
        #hf_hub_id='timm/',
        test_input_size=(3, 256, 256), test_crop_pct=0.95),
    'convnext_femto.d1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnext_femto_ols.d1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnext_pico.d1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnext_pico_ols.d1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_nano.in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_nano.d1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_nano_ols.d1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_tiny_hnf.a2h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

    # 384x384 entries: pool_size is (12, 12) to go with the larger input_size.
    'convnext_nano.r384_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'convnext_tiny.in12k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_small.in12k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # in12k entries: classifier configured with num_classes=11821.
    'convnext_nano.in12k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, num_classes=11821),
    'convnext_nano.r384_in12k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=11821),
    'convnext_nano.r384_ad_in12k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=11821),
    'convnext_tiny.in12k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, num_classes=11821),
    'convnext_small.in12k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, num_classes=11821),

    # Entries whose url points at dl.fbaipublicfiles.com/convnext
    'convnext_tiny.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_small.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_base.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_large.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'convnext_tiny.fb_in1k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_small.fb_in1k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_base.fb_in1k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnext_large.fb_in1k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_small.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_base.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_large.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
        url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # fb_in22k entries: classifier configured with num_classes=21841.
    'convnext_tiny.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),
    'convnext_small.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),
    'convnext_base.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),
    'convnext_large.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),
    'convnext_xlarge.fb_in22k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
        hf_hub_id='timm/',
        num_classes=21841),

    # ConvNeXt-V2 entries, built with `_cfgv2` (cc-by-nc-4.0 license plus paper/origin metadata).
    'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt",
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'),

    'convnextv2_atto.fcmae_ft_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnextv2_femto.fcmae_ft_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnextv2_pico.fcmae_ft_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'convnextv2_nano.fcmae_ft_in1k': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt',
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_tiny.fcmae_ft_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_base.fcmae_ft_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_large.fcmae_ft_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'convnextv2_huge.fcmae_ft_in1k': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt",
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    # '.fcmae' entries (urls under pt_only/) are configured with num_classes=0.
    'convnextv2_atto.fcmae': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_femto.fcmae': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt',
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_pico.fcmae': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt',
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_nano.fcmae': _cfgv2(
        url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt',
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_tiny.fcmae': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt",
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_base.fcmae': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt",
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_large.fcmae': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt",
        hf_hub_id='timm/',
        num_classes=0),
    'convnextv2_huge.fcmae': _cfgv2(
        url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt",
        hf_hub_id='timm/',
        num_classes=0),

    # NOTE(review): this V2 entry is built with `_cfg`, not `_cfgv2`, so it carries the
    # 'apache-2.0' license default and none of the V2 paper metadata. Confirm this is intended.
    'convnextv2_small.untrained': _cfg(),

    # CLIP weights, fine-tuned on in1k or in12k + in1k
    'convnext_base.clip_laion2b_augreg_ft_in12k_in1k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
    'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
    'convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
    'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
    ),
    'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'
    ),
    'convnext_xxlarge.clip_laion2b_soup_ft_in1k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),

    'convnext_base.clip_laion2b_augreg_ft_in12k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
    'convnext_large_mlp.clip_laion2b_soup_ft_in12k_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0),
    'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_large_mlp.clip_laion2b_soup_ft_in12k_384': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_xxlarge.clip_laion2b_soup_ft_in12k': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),

    # CLIP original image tower weights
    'convnext_base.clip_laion2b': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
    'convnext_base.clip_laion2b_augreg': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
    'convnext_base.clip_laiona': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
    'convnext_base.clip_laiona_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
    'convnext_base.clip_laiona_augreg_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
    'convnext_large_mlp.clip_laion2b_augreg': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
    'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
    'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
    'convnext_xxlarge.clip_laion2b_soup': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),
    'convnext_xxlarge.clip_laion2b_rewind': _cfg(
        hf_hub_id='timm/',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=1024),

    # NOTE dinov3 convnext weights are under a specific license, and downstream outputs must be shared with this
    # https://ai.meta.com/resources/models-and-libraries/dinov3-license/
    'convnext_tiny.dinov3_lvd1689m': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=0,
        license='dinov3-license',
    ),
    'convnext_small.dinov3_lvd1689m': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=0,
        license='dinov3-license',
    ),
    'convnext_base.dinov3_lvd1689m': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=0,
        license='dinov3-license',
    ),
    'convnext_large.dinov3_lvd1689m': _cfg(
        hf_hub_id='timm/',
        crop_pct=1.0,
        num_classes=0,
        license='dinov3-license',
    ),

    # Configs for the test_convnext* entrypoints: 160x160 input with a (5, 5) pool size.
    "test_convnext.r160_in1k": _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
    "test_convnext2.r160_in1k": _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
    "test_convnext3.r160_in1k": _cfg(
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
})
  1050. @register_model
  1051. def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
  1052. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1053. model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm')
  1054. model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
  1055. return model
  1056. @register_model
  1057. def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1058. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1059. model_args = dict(
  1060. depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm', stem_type='overlap_act')
  1061. model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1062. return model
  1063. @register_model
  1064. def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
  1065. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1066. model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True)
  1067. model = _create_convnext('convnext_atto', pretrained=pretrained, **dict(model_args, **kwargs))
  1068. return model
  1069. @register_model
  1070. def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1071. # timm femto variant with overlapping 3x3 conv stem, wider than non-ols femto above, current param count 3.7M
  1072. model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, stem_type='overlap_tiered')
  1073. model = _create_convnext('convnext_atto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1074. return model
  1075. @register_model
  1076. def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
  1077. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1078. model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
  1079. model = _create_convnext('convnext_atto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
  1080. return model
  1081. @register_model
  1082. def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
  1083. # timm femto variant
  1084. model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True)
  1085. model = _create_convnext('convnext_femto', pretrained=pretrained, **dict(model_args, **kwargs))
  1086. return model
  1087. @register_model
  1088. def convnext_femto_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1089. # timm femto variant
  1090. model_args = dict(depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), conv_mlp=True, stem_type='overlap_tiered')
  1091. model = _create_convnext('convnext_femto_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1092. return model
  1093. @register_model
  1094. def convnext_pico(pretrained=False, **kwargs) -> ConvNeXt:
  1095. # timm pico variant
  1096. model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True)
  1097. model = _create_convnext('convnext_pico', pretrained=pretrained, **dict(model_args, **kwargs))
  1098. return model
  1099. @register_model
  1100. def convnext_pico_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1101. # timm nano variant with overlapping 3x3 conv stem
  1102. model_args = dict(depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), conv_mlp=True, stem_type='overlap_tiered')
  1103. model = _create_convnext('convnext_pico_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1104. return model
  1105. @register_model
  1106. def convnext_nano(pretrained=False, **kwargs) -> ConvNeXt:
  1107. # timm nano variant with standard stem and head
  1108. model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True)
  1109. model = _create_convnext('convnext_nano', pretrained=pretrained, **dict(model_args, **kwargs))
  1110. return model
  1111. @register_model
  1112. def convnext_nano_ols(pretrained=False, **kwargs) -> ConvNeXt:
  1113. # experimental nano variant with overlapping conv stem
  1114. model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), conv_mlp=True, stem_type='overlap')
  1115. model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **dict(model_args, **kwargs))
  1116. return model
  1117. @register_model
  1118. def convnext_tiny_hnf(pretrained=False, **kwargs) -> ConvNeXt:
  1119. # experimental tiny variant with norm before pooling in head (head norm first)
  1120. model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True)
  1121. model = _create_convnext('convnext_tiny_hnf', pretrained=pretrained, **dict(model_args, **kwargs))
  1122. return model
  1123. @register_model
  1124. def convnext_tiny(pretrained=False, **kwargs) -> ConvNeXt:
  1125. model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768))
  1126. model = _create_convnext('convnext_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
  1127. return model
  1128. @register_model
  1129. def convnext_small(pretrained=False, **kwargs) -> ConvNeXt:
  1130. model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
  1131. model = _create_convnext('convnext_small', pretrained=pretrained, **dict(model_args, **kwargs))
  1132. return model
  1133. @register_model
  1134. def convnext_base(pretrained=False, **kwargs) -> ConvNeXt:
  1135. model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
  1136. model = _create_convnext('convnext_base', pretrained=pretrained, **dict(model_args, **kwargs))
  1137. return model
  1138. @register_model
  1139. def convnext_large(pretrained=False, **kwargs) -> ConvNeXt:
  1140. model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
  1141. model = _create_convnext('convnext_large', pretrained=pretrained, **dict(model_args, **kwargs))
  1142. return model
  1143. @register_model
  1144. def convnext_large_mlp(pretrained=False, **kwargs) -> ConvNeXt:
  1145. model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536)
  1146. model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **dict(model_args, **kwargs))
  1147. return model
  1148. @register_model
  1149. def convnext_xlarge(pretrained=False, **kwargs) -> ConvNeXt:
  1150. model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048])
  1151. model = _create_convnext('convnext_xlarge', pretrained=pretrained, **dict(model_args, **kwargs))
  1152. return model
  1153. @register_model
  1154. def convnext_xxlarge(pretrained=False, **kwargs) -> ConvNeXt:
  1155. model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5))
  1156. model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **dict(model_args, **kwargs))
  1157. return model
  1158. @register_model
  1159. def convnextv2_atto(pretrained=False, **kwargs) -> ConvNeXt:
  1160. # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
  1161. model_args = dict(
  1162. depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True)
  1163. model = _create_convnext('convnextv2_atto', pretrained=pretrained, **dict(model_args, **kwargs))
  1164. return model
  1165. @register_model
  1166. def convnextv2_femto(pretrained=False, **kwargs) -> ConvNeXt:
  1167. # timm femto variant
  1168. model_args = dict(
  1169. depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True)
  1170. model = _create_convnext('convnextv2_femto', pretrained=pretrained, **dict(model_args, **kwargs))
  1171. return model
  1172. @register_model
  1173. def convnextv2_pico(pretrained=False, **kwargs) -> ConvNeXt:
  1174. # timm pico variant
  1175. model_args = dict(
  1176. depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True)
  1177. model = _create_convnext('convnextv2_pico', pretrained=pretrained, **dict(model_args, **kwargs))
  1178. return model
  1179. @register_model
  1180. def convnextv2_nano(pretrained=False, **kwargs) -> ConvNeXt:
  1181. # timm nano variant with standard stem and head
  1182. model_args = dict(
  1183. depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True)
  1184. model = _create_convnext('convnextv2_nano', pretrained=pretrained, **dict(model_args, **kwargs))
  1185. return model
  1186. @register_model
  1187. def convnextv2_tiny(pretrained=False, **kwargs) -> ConvNeXt:
  1188. model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None)
  1189. model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
  1190. return model
  1191. @register_model
  1192. def convnextv2_small(pretrained=False, **kwargs) -> ConvNeXt:
  1193. model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None)
  1194. model = _create_convnext('convnextv2_small', pretrained=pretrained, **dict(model_args, **kwargs))
  1195. return model
  1196. @register_model
  1197. def convnextv2_base(pretrained=False, **kwargs) -> ConvNeXt:
  1198. model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None)
  1199. model = _create_convnext('convnextv2_base', pretrained=pretrained, **dict(model_args, **kwargs))
  1200. return model
  1201. @register_model
  1202. def convnextv2_large(pretrained=False, **kwargs) -> ConvNeXt:
  1203. model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None)
  1204. model = _create_convnext('convnextv2_large', pretrained=pretrained, **dict(model_args, **kwargs))
  1205. return model
  1206. @register_model
  1207. def convnextv2_huge(pretrained=False, **kwargs) -> ConvNeXt:
  1208. model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None)
  1209. model = _create_convnext('convnextv2_huge', pretrained=pretrained, **dict(model_args, **kwargs))
  1210. return model
  1211. @register_model
  1212. def test_convnext(pretrained=False, **kwargs) -> ConvNeXt:
  1213. model_args = dict(depths=[1, 2, 4, 2], dims=[24, 32, 48, 64], norm_eps=kwargs.pop('norm_eps', 1e-5), act_layer='gelu_tanh')
  1214. model = _create_convnext('test_convnext', pretrained=pretrained, **dict(model_args, **kwargs))
  1215. return model
  1216. @register_model
  1217. def test_convnext2(pretrained=False, **kwargs) -> ConvNeXt:
  1218. model_args = dict(depths=[1, 1, 1, 1], dims=[32, 64, 96, 128], norm_eps=kwargs.pop('norm_eps', 1e-5), act_layer='gelu_tanh')
  1219. model = _create_convnext('test_convnext2', pretrained=pretrained, **dict(model_args, **kwargs))
  1220. return model
  1221. @register_model
  1222. def test_convnext3(pretrained=False, **kwargs) -> ConvNeXt:
  1223. model_args = dict(
  1224. depths=[1, 1, 1, 1], dims=[32, 64, 96, 128], norm_eps=kwargs.pop('norm_eps', 1e-5), kernel_sizes=(7, 5, 5, 3), act_layer='silu')
  1225. model = _create_convnext('test_convnext3', pretrained=pretrained, **dict(model_args, **kwargs))
  1226. return model
# Map legacy flat model names (left) onto their current '<model_name>.<tag>' keys in
# default_cfgs (right). Every target below has a matching 'fb_in22k*' entry there.
register_model_deprecations(__name__, {
    # 224px in22k -> in1k fine-tunes
    'convnext_tiny_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k',
    'convnext_small_in22ft1k': 'convnext_small.fb_in22k_ft_in1k',
    'convnext_base_in22ft1k': 'convnext_base.fb_in22k_ft_in1k',
    'convnext_large_in22ft1k': 'convnext_large.fb_in22k_ft_in1k',
    'convnext_xlarge_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k',
    # 384px in22k -> in1k fine-tunes
    'convnext_tiny_384_in22ft1k': 'convnext_tiny.fb_in22k_ft_in1k_384',
    'convnext_small_384_in22ft1k': 'convnext_small.fb_in22k_ft_in1k_384',
    'convnext_base_384_in22ft1k': 'convnext_base.fb_in22k_ft_in1k_384',
    'convnext_large_384_in22ft1k': 'convnext_large.fb_in22k_ft_in1k_384',
    'convnext_xlarge_384_in22ft1k': 'convnext_xlarge.fb_in22k_ft_in1k_384',
    # in22k entries (num_classes=21841 in default_cfgs)
    'convnext_tiny_in22k': 'convnext_tiny.fb_in22k',
    'convnext_small_in22k': 'convnext_small.fb_in22k',
    'convnext_base_in22k': 'convnext_base.fb_in22k',
    'convnext_large_in22k': 'convnext_large.fb_in22k',
    'convnext_xlarge_in22k': 'convnext_xlarge.fb_in22k',
})