efficientnet.py 123 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971
  1. """ The EfficientNet Family in PyTorch
  2. An implementation of EfficientNet that covers a variety of related models with efficient architectures:
  3. * EfficientNet-V2
  4. - `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
  5. * EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent weight ports)
  6. - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946
  7. - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971
  8. - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665
  9. - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252
  10. * MixNet (Small, Medium, and Large)
  11. - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595
  12. * MNasNet B1, A1 (SE), Small
  13. - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626
  14. * FBNet-C
  15. - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443
  16. * Single-Path NAS Pixel1
  17. - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877
  18. * TinyNet
  19. - Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets - https://arxiv.org/abs/2010.14819
  20. - Definitions & weights borrowed from https://github.com/huawei-noah/CV-Backbones/tree/master/tinynet_pytorch
  21. * And likely more...
  22. The majority of the above models (EfficientNet*, MixNet, MnasNet) and original weights were made available
  23. by Mingxing Tan, Quoc Le, and other members of their Google Brain team. Thanks for consistently releasing
  24. the models and weights open source!
  25. Hacked together by / Copyright 2019, Ross Wightman
  26. """
  27. from functools import partial
  28. from typing import Callable, Dict, List, Optional, Tuple, Union
  29. import torch
  30. import torch.nn as nn
  31. import torch.nn.functional as F
  32. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  33. from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, LayerType, \
  34. GroupNormAct, LayerNormAct2d, EvoNorm2dS0
  35. from ._builder import build_model_with_cfg, pretrained_cfg_for_features
  36. from ._efficientnet_blocks import SqueezeExcite
  37. from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
  38. round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
  39. from ._features import FeatureInfo, FeatureHooks, feature_take_indices
  40. from ._manipulate import checkpoint_seq, checkpoint
  41. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
  42. __all__ = ['EfficientNet', 'EfficientNetFeatures']
  43. class EfficientNet(nn.Module):
  44. """EfficientNet model architecture.
  45. A flexible and performant PyTorch implementation of efficient network architectures, including:
  46. * EfficientNet-V2 Small, Medium, Large, XL & B0-B3
  47. * EfficientNet B0-B8, L2
  48. * EfficientNet-EdgeTPU
  49. * EfficientNet-CondConv
  50. * MixNet S, M, L, XL
  51. * MnasNet A1, B1, and small
  52. * MobileNet-V2
  53. * FBNet C
  54. * Single-Path NAS Pixel1
  55. * TinyNet
  56. References:
  57. - EfficientNet: https://arxiv.org/abs/1905.11946
  58. - EfficientNetV2: https://arxiv.org/abs/2104.00298
  59. - MixNet: https://arxiv.org/abs/1907.09595
  60. - MnasNet: https://arxiv.org/abs/1807.11626
  61. """
  62. def __init__(
  63. self,
  64. block_args: BlockArgs,
  65. num_classes: int = 1000,
  66. num_features: int = 1280,
  67. in_chans: int = 3,
  68. stem_size: int = 32,
  69. stem_kernel_size: int = 3,
  70. fix_stem: bool = False,
  71. output_stride: int = 32,
  72. pad_type: str = '',
  73. act_layer: Optional[LayerType] = None,
  74. norm_layer: Optional[LayerType] = None,
  75. aa_layer: Optional[LayerType] = None,
  76. se_layer: Optional[LayerType] = None,
  77. round_chs_fn: Callable = round_channels,
  78. drop_rate: float = 0.,
  79. drop_path_rate: float = 0.,
  80. global_pool: str = 'avg',
  81. device=None,
  82. dtype=None,
  83. ) -> None:
  84. """Initialize EfficientNet model.
  85. Args:
  86. block_args: Arguments for building blocks.
  87. num_classes: Number of classifier classes.
  88. num_features: Number of features for penultimate layer.
  89. in_chans: Number of input channels.
  90. stem_size: Number of output channels in stem.
  91. stem_kernel_size: Kernel size for stem convolution.
  92. fix_stem: If True, don't scale stem channels.
  93. output_stride: Output stride of network.
  94. pad_type: Padding type.
  95. act_layer: Activation layer class.
  96. norm_layer: Normalization layer class.
  97. aa_layer: Anti-aliasing layer class.
  98. se_layer: Squeeze-and-excitation layer class.
  99. round_chs_fn: Channel rounding function.
  100. drop_rate: Dropout rate for classifier.
  101. drop_path_rate: Drop path rate for stochastic depth.
  102. global_pool: Global pooling type.
  103. """
  104. super().__init__()
  105. dd = {'device': device, 'dtype': dtype}
  106. act_layer = act_layer or nn.ReLU
  107. norm_layer = norm_layer or nn.BatchNorm2d
  108. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  109. se_layer = se_layer or SqueezeExcite
  110. self.num_classes = num_classes
  111. self.drop_rate = drop_rate
  112. self.grad_checkpointing = False
  113. # Stem
  114. if not fix_stem:
  115. stem_size = round_chs_fn(stem_size)
  116. self.conv_stem = create_conv2d(in_chans, stem_size, stem_kernel_size, stride=2, padding=pad_type, **dd)
  117. self.bn1 = norm_act_layer(stem_size, inplace=True, **dd)
  118. # Middle stages (IR/ER/DS Blocks)
  119. builder = EfficientNetBuilder(
  120. output_stride=output_stride,
  121. pad_type=pad_type,
  122. round_chs_fn=round_chs_fn,
  123. act_layer=act_layer,
  124. norm_layer=norm_layer,
  125. aa_layer=aa_layer,
  126. se_layer=se_layer,
  127. drop_path_rate=drop_path_rate,
  128. **dd,
  129. )
  130. self.blocks = nn.Sequential(*builder(stem_size, block_args))
  131. self.feature_info = builder.features
  132. self.stage_ends = [f['stage'] for f in self.feature_info]
  133. head_chs = builder.in_chs
  134. # Head + Pooling
  135. if num_features > 0:
  136. self.conv_head = create_conv2d(head_chs, num_features, 1, padding=pad_type, **dd)
  137. self.bn2 = norm_act_layer(num_features, inplace=True, **dd)
  138. self.num_features = self.head_hidden_size = num_features
  139. else:
  140. self.conv_head = nn.Identity()
  141. self.bn2 = nn.Identity()
  142. self.num_features = self.head_hidden_size = head_chs
  143. self.global_pool, self.classifier = create_classifier(
  144. self.num_features,
  145. self.num_classes,
  146. pool_type=global_pool,
  147. **dd,
  148. )
  149. efficientnet_init_weights(self)
  150. def as_sequential(self) -> nn.Sequential:
  151. """Convert model to sequential for feature extraction."""
  152. layers = [self.conv_stem, self.bn1]
  153. layers.extend(self.blocks)
  154. layers.extend([self.conv_head, self.bn2, self.global_pool])
  155. layers.extend([nn.Dropout(self.drop_rate), self.classifier])
  156. return nn.Sequential(*layers)
  157. @torch.jit.ignore
  158. def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]:
  159. """Create regex patterns for parameter groups.
  160. Args:
  161. coarse: Use coarse (stage-level) grouping.
  162. Returns:
  163. Dictionary mapping group names to regex patterns.
  164. """
  165. return dict(
  166. stem=r'^conv_stem|bn1',
  167. blocks=[
  168. (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
  169. (r'conv_head|bn2', (99999,))
  170. ]
  171. )
  172. @torch.jit.ignore
  173. def set_grad_checkpointing(self, enable: bool = True) -> None:
  174. """Enable or disable gradient checkpointing.
  175. Args:
  176. enable: Whether to enable gradient checkpointing.
  177. """
  178. self.grad_checkpointing = enable
  179. @torch.jit.ignore
  180. def get_classifier(self) -> nn.Module:
  181. """Get the classifier module."""
  182. return self.classifier
  183. def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None:
  184. """Reset the classifier head.
  185. Args:
  186. num_classes: Number of classes for new classifier.
  187. global_pool: Global pooling type.
  188. """
  189. self.num_classes = num_classes
  190. self.global_pool, self.classifier = create_classifier(
  191. self.num_features, self.num_classes, pool_type=global_pool)
  192. def forward_intermediates(
  193. self,
  194. x: torch.Tensor,
  195. indices: Optional[Union[int, List[int]]] = None,
  196. norm: bool = False,
  197. stop_early: bool = False,
  198. output_fmt: str = 'NCHW',
  199. intermediates_only: bool = False,
  200. extra_blocks: bool = False,
  201. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  202. """Forward features that returns intermediates.
  203. Args:
  204. x: Input image tensor.
  205. indices: Take last n blocks if int, all if None, select matching indices if sequence.
  206. norm: Apply norm layer to compatible intermediates.
  207. stop_early: Stop iterating over blocks when last desired intermediate hit.
  208. output_fmt: Shape of intermediate feature outputs.
  209. intermediates_only: Only return intermediate features.
  210. extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info.
  211. Returns:
  212. List of intermediate features or tuple of (final features, intermediates).
  213. """
  214. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  215. intermediates = []
  216. if extra_blocks:
  217. take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
  218. else:
  219. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  220. take_indices = [self.stage_ends[i] for i in take_indices]
  221. max_index = self.stage_ends[max_index]
  222. # forward pass
  223. feat_idx = 0 # stem is index 0
  224. x = self.conv_stem(x)
  225. x = self.bn1(x)
  226. if feat_idx in take_indices:
  227. intermediates.append(x)
  228. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  229. blocks = self.blocks
  230. else:
  231. blocks = self.blocks[:max_index]
  232. for feat_idx, blk in enumerate(blocks, start=1):
  233. if self.grad_checkpointing and not torch.jit.is_scripting():
  234. x = checkpoint_seq(blk, x)
  235. else:
  236. x = blk(x)
  237. if feat_idx in take_indices:
  238. intermediates.append(x)
  239. if intermediates_only:
  240. return intermediates
  241. if feat_idx == self.stage_ends[-1]:
  242. x = self.conv_head(x)
  243. x = self.bn2(x)
  244. return x, intermediates
  245. def prune_intermediate_layers(
  246. self,
  247. indices: Union[int, List[int]] = 1,
  248. prune_norm: bool = False,
  249. prune_head: bool = True,
  250. extra_blocks: bool = False,
  251. ) -> List[int]:
  252. """Prune layers not required for specified intermediates.
  253. Args:
  254. indices: Indices of intermediate layers to keep.
  255. prune_norm: Whether to prune normalization layers.
  256. prune_head: Whether to prune the classifier head.
  257. extra_blocks: Include all blocks in indexing.
  258. Returns:
  259. List of indices that were kept.
  260. """
  261. if extra_blocks:
  262. take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
  263. else:
  264. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  265. max_index = self.stage_ends[max_index]
  266. self.blocks = self.blocks[:max_index] # truncate blocks w/ stem as idx 0
  267. if prune_norm or max_index < len(self.blocks):
  268. self.conv_head = nn.Identity()
  269. self.bn2 = nn.Identity()
  270. if prune_head:
  271. self.reset_classifier(0, '')
  272. return take_indices
  273. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  274. """Forward pass through feature extraction layers."""
  275. x = self.conv_stem(x)
  276. x = self.bn1(x)
  277. if self.grad_checkpointing and not torch.jit.is_scripting():
  278. x = checkpoint_seq(self.blocks, x, flatten=True)
  279. else:
  280. x = self.blocks(x)
  281. x = self.conv_head(x)
  282. x = self.bn2(x)
  283. return x
  284. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  285. """Forward pass through classifier head.
  286. Args:
  287. x: Feature tensor.
  288. pre_logits: Return features before final classifier.
  289. Returns:
  290. Output tensor.
  291. """
  292. x = self.global_pool(x)
  293. if self.drop_rate > 0.:
  294. x = F.dropout(x, p=self.drop_rate, training=self.training)
  295. return x if pre_logits else self.classifier(x)
  296. def forward(self, x: torch.Tensor) -> torch.Tensor:
  297. """Forward pass."""
  298. x = self.forward_features(x)
  299. x = self.forward_head(x)
  300. return x
  301. class EfficientNetFeatures(nn.Module):
  302. """ EfficientNet Feature Extractor
  303. A work-in-progress feature extraction module for EfficientNet, to use as a backbone for segmentation
  304. and object detection models.
  305. """
  306. def __init__(
  307. self,
  308. block_args: BlockArgs,
  309. out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
  310. feature_location: str = 'bottleneck',
  311. in_chans: int = 3,
  312. stem_size: int = 32,
  313. stem_kernel_size: int = 3,
  314. fix_stem: bool = False,
  315. output_stride: int = 32,
  316. pad_type: str = '',
  317. act_layer: Optional[LayerType] = None,
  318. norm_layer: Optional[LayerType] = None,
  319. aa_layer: Optional[LayerType] = None,
  320. se_layer: Optional[LayerType] = None,
  321. round_chs_fn: Callable = round_channels,
  322. drop_rate: float = 0.,
  323. drop_path_rate: float = 0.,
  324. device=None,
  325. dtype=None,
  326. ):
  327. super().__init__()
  328. dd = {'device': device, 'dtype': dtype}
  329. act_layer = act_layer or nn.ReLU
  330. norm_layer = norm_layer or nn.BatchNorm2d
  331. norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
  332. se_layer = se_layer or SqueezeExcite
  333. self.drop_rate = drop_rate
  334. self.grad_checkpointing = False
  335. # Stem
  336. if not fix_stem:
  337. stem_size = round_chs_fn(stem_size)
  338. self.conv_stem = create_conv2d(in_chans, stem_size, stem_kernel_size, stride=2, padding=pad_type, **dd)
  339. self.bn1 = norm_act_layer(stem_size, inplace=True, **dd)
  340. # Middle stages (IR/ER/DS Blocks)
  341. builder = EfficientNetBuilder(
  342. output_stride=output_stride,
  343. pad_type=pad_type,
  344. round_chs_fn=round_chs_fn,
  345. act_layer=act_layer,
  346. norm_layer=norm_layer,
  347. aa_layer=aa_layer,
  348. se_layer=se_layer,
  349. drop_path_rate=drop_path_rate,
  350. feature_location=feature_location,
  351. **dd,
  352. )
  353. self.blocks = nn.Sequential(*builder(stem_size, block_args))
  354. self.feature_info = FeatureInfo(builder.features, out_indices)
  355. self._stage_out_idx = {f['stage']: f['index'] for f in self.feature_info.get_dicts()}
  356. efficientnet_init_weights(self)
  357. # Register feature extraction hooks with FeatureHooks helper
  358. self.feature_hooks = None
  359. if feature_location != 'bottleneck':
  360. hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
  361. self.feature_hooks = FeatureHooks(hooks, self.named_modules())
  362. @torch.jit.ignore
  363. def set_grad_checkpointing(self, enable: bool = True) -> None:
  364. """Enable or disable gradient checkpointing.
  365. Args:
  366. enable: Whether to enable gradient checkpointing.
  367. """
  368. self.grad_checkpointing = enable
  369. def forward(self, x) -> List[torch.Tensor]:
  370. x = self.conv_stem(x)
  371. x = self.bn1(x)
  372. if self.feature_hooks is None:
  373. features = []
  374. if 0 in self._stage_out_idx:
  375. features.append(x) # add stem out
  376. for i, b in enumerate(self.blocks):
  377. if self.grad_checkpointing and not torch.jit.is_scripting():
  378. x = checkpoint(b, x)
  379. else:
  380. x = b(x)
  381. if i + 1 in self._stage_out_idx:
  382. features.append(x)
  383. return features
  384. else:
  385. self.blocks(x)
  386. out = self.feature_hooks.get_output(x.device)
  387. return list(out.values())
  388. def _create_effnet(variant, pretrained=False, **kwargs):
  389. features_mode = ''
  390. model_cls = EfficientNet
  391. kwargs_filter = None
  392. if kwargs.pop('features_only', False):
  393. if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
  394. features_mode = 'cfg'
  395. else:
  396. kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
  397. model_cls = EfficientNetFeatures
  398. features_mode = 'cls'
  399. pretrained_strict = kwargs.pop('pretrained_strict', True)
  400. model = build_model_with_cfg(
  401. model_cls,
  402. variant,
  403. pretrained,
  404. features_only=features_mode == 'cfg',
  405. pretrained_strict=pretrained_strict and features_mode != 'cls',
  406. kwargs_filter=kwargs_filter,
  407. **kwargs,
  408. )
  409. if features_mode == 'cls':
  410. model.pretrained_cfg = model.default_cfg = pretrained_cfg_for_features(model.pretrained_cfg)
  411. return model
  412. def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
  413. """Creates a mnasnet-a1 model.
  414. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
  415. Paper: https://arxiv.org/pdf/1807.11626.pdf.
  416. Args:
  417. channel_multiplier: multiplier to number of channels per layer.
  418. """
  419. arch_def = [
  420. # stage 0, 112x112 in
  421. ['ds_r1_k3_s1_e1_c16_noskip'],
  422. # stage 1, 112x112 in
  423. ['ir_r2_k3_s2_e6_c24'],
  424. # stage 2, 56x56 in
  425. ['ir_r3_k5_s2_e3_c40_se0.25'],
  426. # stage 3, 28x28 in
  427. ['ir_r4_k3_s2_e6_c80'],
  428. # stage 4, 14x14in
  429. ['ir_r2_k3_s1_e6_c112_se0.25'],
  430. # stage 5, 14x14in
  431. ['ir_r3_k5_s2_e6_c160_se0.25'],
  432. # stage 6, 7x7 in
  433. ['ir_r1_k3_s1_e6_c320'],
  434. ]
  435. model_kwargs = dict(
  436. block_args=decode_arch_def(arch_def),
  437. stem_size=32,
  438. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  439. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  440. **kwargs
  441. )
  442. model = _create_effnet(variant, pretrained, **model_kwargs)
  443. return model
  444. def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
  445. """Creates a mnasnet-b1 model.
  446. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
  447. Paper: https://arxiv.org/pdf/1807.11626.pdf.
  448. Args:
  449. channel_multiplier: multiplier to number of channels per layer.
  450. """
  451. arch_def = [
  452. # stage 0, 112x112 in
  453. ['ds_r1_k3_s1_c16_noskip'],
  454. # stage 1, 112x112 in
  455. ['ir_r3_k3_s2_e3_c24'],
  456. # stage 2, 56x56 in
  457. ['ir_r3_k5_s2_e3_c40'],
  458. # stage 3, 28x28 in
  459. ['ir_r3_k5_s2_e6_c80'],
  460. # stage 4, 14x14in
  461. ['ir_r2_k3_s1_e6_c96'],
  462. # stage 5, 14x14in
  463. ['ir_r4_k5_s2_e6_c192'],
  464. # stage 6, 7x7 in
  465. ['ir_r1_k3_s1_e6_c320_noskip']
  466. ]
  467. model_kwargs = dict(
  468. block_args=decode_arch_def(arch_def),
  469. stem_size=32,
  470. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  471. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  472. **kwargs
  473. )
  474. model = _create_effnet(variant, pretrained, **model_kwargs)
  475. return model
  476. def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
  477. """Creates a mnasnet-b1 model.
  478. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
  479. Paper: https://arxiv.org/pdf/1807.11626.pdf.
  480. Args:
  481. channel_multiplier: multiplier to number of channels per layer.
  482. """
  483. arch_def = [
  484. ['ds_r1_k3_s1_c8'],
  485. ['ir_r1_k3_s2_e3_c16'],
  486. ['ir_r2_k3_s2_e6_c16'],
  487. ['ir_r4_k5_s2_e6_c32_se0.25'],
  488. ['ir_r3_k3_s1_e6_c32_se0.25'],
  489. ['ir_r3_k5_s2_e6_c88_se0.25'],
  490. ['ir_r1_k3_s1_e6_c144']
  491. ]
  492. model_kwargs = dict(
  493. block_args=decode_arch_def(arch_def),
  494. stem_size=8,
  495. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  496. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  497. **kwargs
  498. )
  499. model = _create_effnet(variant, pretrained, **model_kwargs)
  500. return model
  501. def _gen_mobilenet_v1(
  502. variant, channel_multiplier=1.0, depth_multiplier=1.0,
  503. group_size=None, fix_stem_head=False, head_conv=False, pretrained=False, **kwargs
  504. ):
  505. """
  506. Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
  507. Paper: https://arxiv.org/abs/1801.04381
  508. """
  509. arch_def = [
  510. ['dsa_r1_k3_s1_c64'],
  511. ['dsa_r2_k3_s2_c128'],
  512. ['dsa_r2_k3_s2_c256'],
  513. ['dsa_r6_k3_s2_c512'],
  514. ['dsa_r2_k3_s2_c1024'],
  515. ]
  516. round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
  517. head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
  518. model_kwargs = dict(
  519. block_args=decode_arch_def(
  520. arch_def,
  521. depth_multiplier=depth_multiplier,
  522. fix_first_last=fix_stem_head,
  523. group_size=group_size,
  524. ),
  525. num_features=head_features,
  526. stem_size=32,
  527. fix_stem=fix_stem_head,
  528. round_chs_fn=round_chs_fn,
  529. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  530. act_layer=resolve_act_layer(kwargs, 'relu6'),
  531. **kwargs
  532. )
  533. model = _create_effnet(variant, pretrained, **model_kwargs)
  534. return model
  535. def _gen_mobilenet_v2(
  536. variant, channel_multiplier=1.0, depth_multiplier=1.0,
  537. group_size=None, fix_stem_head=False, pretrained=False, **kwargs
  538. ):
  539. """ Generate MobileNet-V2 network
  540. Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
  541. Paper: https://arxiv.org/abs/1801.04381
  542. """
  543. arch_def = [
  544. ['ds_r1_k3_s1_c16'],
  545. ['ir_r2_k3_s2_e6_c24'],
  546. ['ir_r3_k3_s2_e6_c32'],
  547. ['ir_r4_k3_s2_e6_c64'],
  548. ['ir_r3_k3_s1_e6_c96'],
  549. ['ir_r3_k3_s2_e6_c160'],
  550. ['ir_r1_k3_s1_e6_c320'],
  551. ]
  552. round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
  553. model_kwargs = dict(
  554. block_args=decode_arch_def(
  555. arch_def,
  556. depth_multiplier=depth_multiplier,
  557. fix_first_last=fix_stem_head,
  558. group_size=group_size,
  559. ),
  560. num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)),
  561. stem_size=32,
  562. fix_stem=fix_stem_head,
  563. round_chs_fn=round_chs_fn,
  564. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  565. act_layer=resolve_act_layer(kwargs, 'relu6'),
  566. **kwargs
  567. )
  568. model = _create_effnet(variant, pretrained, **model_kwargs)
  569. return model
  570. def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
  571. """ FBNet-C
  572. Paper: https://arxiv.org/abs/1812.03443
  573. Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py
  574. NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper,
  575. it was used to confirm some building block details
  576. """
  577. arch_def = [
  578. ['ir_r1_k3_s1_e1_c16'],
  579. ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'],
  580. ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'],
  581. ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'],
  582. ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'],
  583. ['ir_r4_k5_s2_e6_c184'],
  584. ['ir_r1_k3_s1_e6_c352'],
  585. ]
  586. model_kwargs = dict(
  587. block_args=decode_arch_def(arch_def),
  588. stem_size=16,
  589. num_features=1984, # paper suggests this, but is not 100% clear
  590. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  591. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  592. **kwargs
  593. )
  594. model = _create_effnet(variant, pretrained, **model_kwargs)
  595. return model
  596. def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
  597. """Creates the Single-Path NAS model from search targeted for Pixel1 phone.
  598. Paper: https://arxiv.org/abs/1904.02877
  599. Args:
  600. channel_multiplier: multiplier to number of channels per layer.
  601. """
  602. arch_def = [
  603. # stage 0, 112x112 in
  604. ['ds_r1_k3_s1_c16_noskip'],
  605. # stage 1, 112x112 in
  606. ['ir_r3_k3_s2_e3_c24'],
  607. # stage 2, 56x56 in
  608. ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'],
  609. # stage 3, 28x28 in
  610. ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'],
  611. # stage 4, 14x14in
  612. ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'],
  613. # stage 5, 14x14in
  614. ['ir_r4_k5_s2_e6_c192'],
  615. # stage 6, 7x7 in
  616. ['ir_r1_k3_s1_e6_c320_noskip']
  617. ]
  618. model_kwargs = dict(
  619. block_args=decode_arch_def(arch_def),
  620. stem_size=32,
  621. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  622. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  623. **kwargs
  624. )
  625. model = _create_effnet(variant, pretrained, **model_kwargs)
  626. return model
  627. def _gen_efficientnet(
  628. variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
  629. group_size=None, pretrained=False, **kwargs
  630. ):
  631. """Creates an EfficientNet model.
  632. Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
  633. Paper: https://arxiv.org/abs/1905.11946
  634. EfficientNet params
  635. name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
  636. 'efficientnet-b0': (1.0, 1.0, 224, 0.2),
  637. 'efficientnet-b1': (1.0, 1.1, 240, 0.2),
  638. 'efficientnet-b2': (1.1, 1.2, 260, 0.3),
  639. 'efficientnet-b3': (1.2, 1.4, 300, 0.3),
  640. 'efficientnet-b4': (1.4, 1.8, 380, 0.4),
  641. 'efficientnet-b5': (1.6, 2.2, 456, 0.4),
  642. 'efficientnet-b6': (1.8, 2.6, 528, 0.5),
  643. 'efficientnet-b7': (2.0, 3.1, 600, 0.5),
  644. 'efficientnet-b8': (2.2, 3.6, 672, 0.5),
  645. 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
  646. Args:
  647. channel_multiplier: multiplier to number of channels per layer
  648. depth_multiplier: multiplier to number of repeats per stage
  649. """
  650. arch_def = [
  651. ['ds_r1_k3_s1_e1_c16_se0.25'],
  652. ['ir_r2_k3_s2_e6_c24_se0.25'],
  653. ['ir_r2_k5_s2_e6_c40_se0.25'],
  654. ['ir_r3_k3_s2_e6_c80_se0.25'],
  655. ['ir_r3_k5_s1_e6_c112_se0.25'],
  656. ['ir_r4_k5_s2_e6_c192_se0.25'],
  657. ['ir_r1_k3_s1_e6_c320_se0.25'],
  658. ]
  659. round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor)
  660. model_kwargs = dict(
  661. block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
  662. num_features=round_chs_fn(1280),
  663. stem_size=32,
  664. round_chs_fn=round_chs_fn,
  665. act_layer=resolve_act_layer(kwargs, 'swish'),
  666. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  667. **kwargs,
  668. )
  669. model = _create_effnet(variant, pretrained, **model_kwargs)
  670. return model
  671. def _gen_efficientnet_edge(
  672. variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
  673. ):
  674. """ Creates an EfficientNet-EdgeTPU model
  675. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
  676. """
  677. arch_def = [
  678. # NOTE `fc` is present to override a mismatch between stem channels and in chs not
  679. # present in other models
  680. ['er_r1_k3_s1_e4_c24_fc24_noskip'],
  681. ['er_r2_k3_s2_e8_c32'],
  682. ['er_r4_k3_s2_e8_c48'],
  683. ['ir_r5_k5_s2_e8_c96'],
  684. ['ir_r4_k5_s1_e8_c144'],
  685. ['ir_r2_k5_s2_e8_c192'],
  686. ]
  687. round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
  688. model_kwargs = dict(
  689. block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
  690. num_features=round_chs_fn(1280),
  691. stem_size=32,
  692. round_chs_fn=round_chs_fn,
  693. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  694. act_layer=resolve_act_layer(kwargs, 'relu'),
  695. **kwargs,
  696. )
  697. model = _create_effnet(variant, pretrained, **model_kwargs)
  698. return model
  699. def _gen_efficientnet_condconv(
  700. variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs
  701. ):
  702. """Creates an EfficientNet-CondConv model.
  703. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
  704. """
  705. arch_def = [
  706. ['ds_r1_k3_s1_e1_c16_se0.25'],
  707. ['ir_r2_k3_s2_e6_c24_se0.25'],
  708. ['ir_r2_k5_s2_e6_c40_se0.25'],
  709. ['ir_r3_k3_s2_e6_c80_se0.25'],
  710. ['ir_r3_k5_s1_e6_c112_se0.25_cc4'],
  711. ['ir_r4_k5_s2_e6_c192_se0.25_cc4'],
  712. ['ir_r1_k3_s1_e6_c320_se0.25_cc4'],
  713. ]
  714. # NOTE unlike official impl, this one uses `cc<x>` option where x is the base number of experts for each stage and
  715. # the expert_multiplier increases that on a per-model basis as with depth/channel multipliers
  716. round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
  717. model_kwargs = dict(
  718. block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier),
  719. num_features=round_chs_fn(1280),
  720. stem_size=32,
  721. round_chs_fn=round_chs_fn,
  722. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  723. act_layer=resolve_act_layer(kwargs, 'swish'),
  724. **kwargs,
  725. )
  726. model = _create_effnet(variant, pretrained, **model_kwargs)
  727. return model
  728. def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
  729. """Creates an EfficientNet-Lite model.
  730. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite
  731. Paper: https://arxiv.org/abs/1905.11946
  732. EfficientNet params
  733. name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
  734. 'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
  735. 'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
  736. 'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
  737. 'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
  738. 'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
  739. Args:
  740. channel_multiplier: multiplier to number of channels per layer
  741. depth_multiplier: multiplier to number of repeats per stage
  742. """
  743. arch_def = [
  744. ['ds_r1_k3_s1_e1_c16'],
  745. ['ir_r2_k3_s2_e6_c24'],
  746. ['ir_r2_k5_s2_e6_c40'],
  747. ['ir_r3_k3_s2_e6_c80'],
  748. ['ir_r3_k5_s1_e6_c112'],
  749. ['ir_r4_k5_s2_e6_c192'],
  750. ['ir_r1_k3_s1_e6_c320'],
  751. ]
  752. model_kwargs = dict(
  753. block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True),
  754. num_features=1280,
  755. stem_size=32,
  756. fix_stem=True,
  757. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  758. act_layer=resolve_act_layer(kwargs, 'relu6'),
  759. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  760. **kwargs,
  761. )
  762. model = _create_effnet(variant, pretrained, **model_kwargs)
  763. return model
  764. def _gen_efficientnetv2_base(
  765. variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
  766. ):
  767. """ Creates an EfficientNet-V2 base model
  768. Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
  769. Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
  770. """
  771. arch_def = [
  772. ['cn_r1_k3_s1_e1_c16_skip'],
  773. ['er_r2_k3_s2_e4_c32'],
  774. ['er_r2_k3_s2_e4_c48'],
  775. ['ir_r3_k3_s2_e4_c96_se0.25'],
  776. ['ir_r5_k3_s1_e6_c112_se0.25'],
  777. ['ir_r8_k3_s2_e6_c192_se0.25'],
  778. ]
  779. round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
  780. model_kwargs = dict(
  781. block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
  782. num_features=round_chs_fn(1280),
  783. stem_size=32,
  784. round_chs_fn=round_chs_fn,
  785. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  786. act_layer=resolve_act_layer(kwargs, 'silu'),
  787. **kwargs,
  788. )
  789. model = _create_effnet(variant, pretrained, **model_kwargs)
  790. return model
  791. def _gen_efficientnetv2_s(
  792. variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs
  793. ):
  794. """ Creates an EfficientNet-V2 Small model
  795. Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
  796. Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
  797. NOTE: `rw` flag sets up 'small' variant to behave like my initial v2 small model,
  798. before ref the impl was released.
  799. """
  800. arch_def = [
  801. ['cn_r2_k3_s1_e1_c24_skip'],
  802. ['er_r4_k3_s2_e4_c48'],
  803. ['er_r4_k3_s2_e4_c64'],
  804. ['ir_r6_k3_s2_e4_c128_se0.25'],
  805. ['ir_r9_k3_s1_e6_c160_se0.25'],
  806. ['ir_r15_k3_s2_e6_c256_se0.25'],
  807. ]
  808. num_features = 1280
  809. if rw:
  810. # my original variant, based on paper figure differs from the official release
  811. arch_def[0] = ['er_r2_k3_s1_e1_c24']
  812. arch_def[-1] = ['ir_r15_k3_s2_e6_c272_se0.25']
  813. num_features = 1792
  814. round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
  815. model_kwargs = dict(
  816. block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
  817. num_features=round_chs_fn(num_features),
  818. stem_size=24,
  819. round_chs_fn=round_chs_fn,
  820. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  821. act_layer=resolve_act_layer(kwargs, 'silu'),
  822. **kwargs,
  823. )
  824. model = _create_effnet(variant, pretrained, **model_kwargs)
  825. return model
  826. def _gen_efficientnetv2_m(
  827. variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
  828. ):
  829. """ Creates an EfficientNet-V2 Medium model
  830. Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
  831. Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
  832. """
  833. arch_def = [
  834. ['cn_r3_k3_s1_e1_c24_skip'],
  835. ['er_r5_k3_s2_e4_c48'],
  836. ['er_r5_k3_s2_e4_c80'],
  837. ['ir_r7_k3_s2_e4_c160_se0.25'],
  838. ['ir_r14_k3_s1_e6_c176_se0.25'],
  839. ['ir_r18_k3_s2_e6_c304_se0.25'],
  840. ['ir_r5_k3_s1_e6_c512_se0.25'],
  841. ]
  842. model_kwargs = dict(
  843. block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
  844. num_features=1280,
  845. stem_size=24,
  846. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  847. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  848. act_layer=resolve_act_layer(kwargs, 'silu'),
  849. **kwargs,
  850. )
  851. model = _create_effnet(variant, pretrained, **model_kwargs)
  852. return model
  853. def _gen_efficientnetv2_l(
  854. variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
  855. ):
  856. """ Creates an EfficientNet-V2 Large model
  857. Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
  858. Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
  859. """
  860. arch_def = [
  861. ['cn_r4_k3_s1_e1_c32_skip'],
  862. ['er_r7_k3_s2_e4_c64'],
  863. ['er_r7_k3_s2_e4_c96'],
  864. ['ir_r10_k3_s2_e4_c192_se0.25'],
  865. ['ir_r19_k3_s1_e6_c224_se0.25'],
  866. ['ir_r25_k3_s2_e6_c384_se0.25'],
  867. ['ir_r7_k3_s1_e6_c640_se0.25'],
  868. ]
  869. model_kwargs = dict(
  870. block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
  871. num_features=1280,
  872. stem_size=32,
  873. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  874. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  875. act_layer=resolve_act_layer(kwargs, 'silu'),
  876. **kwargs,
  877. )
  878. model = _create_effnet(variant, pretrained, **model_kwargs)
  879. return model
  880. def _gen_efficientnetv2_xl(
  881. variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
  882. ):
  883. """ Creates an EfficientNet-V2 Xtra-Large model
  884. Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
  885. Paper: `EfficientNetV2: Smaller Models and Faster Training` - https://arxiv.org/abs/2104.00298
  886. """
  887. arch_def = [
  888. ['cn_r4_k3_s1_e1_c32_skip'],
  889. ['er_r8_k3_s2_e4_c64'],
  890. ['er_r8_k3_s2_e4_c96'],
  891. ['ir_r16_k3_s2_e4_c192_se0.25'],
  892. ['ir_r24_k3_s1_e6_c256_se0.25'],
  893. ['ir_r32_k3_s2_e6_c512_se0.25'],
  894. ['ir_r8_k3_s1_e6_c640_se0.25'],
  895. ]
  896. model_kwargs = dict(
  897. block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
  898. num_features=1280,
  899. stem_size=32,
  900. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  901. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  902. act_layer=resolve_act_layer(kwargs, 'silu'),
  903. **kwargs,
  904. )
  905. model = _create_effnet(variant, pretrained, **model_kwargs)
  906. return model
  907. def _gen_efficientnet_x(
  908. variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
  909. group_size=None, version=1, pretrained=False, **kwargs
  910. ):
  911. """Creates an EfficientNet model.
  912. Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
  913. Paper: https://arxiv.org/abs/1905.11946
  914. EfficientNet params
  915. name: (channel_multiplier, depth_multiplier, resolution, dropout_rate)
  916. 'efficientnet-x-b0': (1.0, 1.0, 224, 0.2),
  917. 'efficientnet-x-b1': (1.0, 1.1, 240, 0.2),
  918. 'efficientnet-x-b2': (1.1, 1.2, 260, 0.3),
  919. 'efficientnet-x-b3': (1.2, 1.4, 300, 0.3),
  920. 'efficientnet-x-b4': (1.4, 1.8, 380, 0.4),
  921. 'efficientnet-x-b5': (1.6, 2.2, 456, 0.4),
  922. 'efficientnet-x-b6': (1.8, 2.6, 528, 0.5),
  923. 'efficientnet-x-b7': (2.0, 3.1, 600, 0.5),
  924. 'efficientnet-x-b8': (2.2, 3.6, 672, 0.5),
  925. 'efficientnet-l2': (4.3, 5.3, 800, 0.5),
  926. Args:
  927. channel_multiplier: multiplier to number of channels per layer
  928. depth_multiplier: multiplier to number of repeats per stage
  929. """
  930. """
  931. if version == 1:
  932. blocks_args = [
  933. 'r1_k3_s11_e1_i32_o16_se0.25_d1_a0',
  934. 'r2_k3_s22_e6_i16_o24_se0.25_f1_d2_a1',
  935. 'r2_k5_s22_e6_i24_o40_se0.25_f1_a1',
  936. 'r3_k3_s22_e6_i40_o80_se0.25_a0',
  937. 'r3_k5_s11_e6_i80_o112_se0.25_a0',
  938. 'r4_k5_s22_e6_i112_o192_se0.25_a0',
  939. 'r1_k3_s11_e6_i192_o320_se0.25_a0',
  940. ]
  941. elif version == 2:
  942. blocks_args = [
  943. 'r1_k3_s11_e1_i32_o16_se0.25_d1_a0',
  944. 'r2_k3_s22_e4_i16_o24_se0.25_f1_d2_a1',
  945. 'r2_k5_s22_e4_i24_o40_se0.25_f1_a1',
  946. 'r3_k3_s22_e4_i40_o80_se0.25_a0',
  947. 'r3_k5_s11_e6_i80_o112_se0.25_a0',
  948. 'r4_k5_s22_e6_i112_o192_se0.25_a0',
  949. 'r1_k3_s11_e6_i192_o320_se0.25_a0',
  950. ]
  951. """
  952. if version == 1:
  953. arch_def = [
  954. ['ds_r1_k3_s1_e1_c16_se0.25_d1'],
  955. ['er_r2_k3_s2_e6_c24_se0.25_nre'],
  956. ['er_r2_k5_s2_e6_c40_se0.25_nre'],
  957. ['ir_r3_k3_s2_e6_c80_se0.25'],
  958. ['ir_r3_k5_s1_e6_c112_se0.25'],
  959. ['ir_r4_k5_s2_e6_c192_se0.25'],
  960. ['ir_r1_k3_s1_e6_c320_se0.25'],
  961. ]
  962. else:
  963. arch_def = [
  964. ['ds_r1_k3_s1_e1_c16_se0.25_d1'],
  965. ['er_r2_k3_s2_e4_c24_se0.25_nre'],
  966. ['er_r2_k5_s2_e4_c40_se0.25_nre'],
  967. ['ir_r3_k3_s2_e4_c80_se0.25'],
  968. ['ir_r3_k5_s1_e6_c112_se0.25'],
  969. ['ir_r4_k5_s2_e6_c192_se0.25'],
  970. ['ir_r1_k3_s1_e6_c320_se0.25'],
  971. ]
  972. round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor)
  973. model_kwargs = dict(
  974. block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
  975. num_features=round_chs_fn(1280),
  976. stem_size=32,
  977. round_chs_fn=round_chs_fn,
  978. act_layer=resolve_act_layer(kwargs, 'silu'),
  979. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  980. **kwargs,
  981. )
  982. model = _create_effnet(variant, pretrained, **model_kwargs)
  983. return model
  984. def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
  985. """Creates a MixNet Small model.
  986. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
  987. Paper: https://arxiv.org/abs/1907.09595
  988. """
  989. arch_def = [
  990. # stage 0, 112x112 in
  991. ['ds_r1_k3_s1_e1_c16'], # relu
  992. # stage 1, 112x112 in
  993. ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu
  994. # stage 2, 56x56 in
  995. ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
  996. # stage 3, 28x28 in
  997. ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish
  998. # stage 4, 14x14in
  999. ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
  1000. # stage 5, 14x14in
  1001. ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
  1002. # 7x7
  1003. ]
  1004. model_kwargs = dict(
  1005. block_args=decode_arch_def(arch_def),
  1006. num_features=1536,
  1007. stem_size=16,
  1008. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  1009. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  1010. **kwargs
  1011. )
  1012. model = _create_effnet(variant, pretrained, **model_kwargs)
  1013. return model
  1014. def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
  1015. """Creates a MixNet Medium-Large model.
  1016. Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
  1017. Paper: https://arxiv.org/abs/1907.09595
  1018. """
  1019. arch_def = [
  1020. # stage 0, 112x112 in
  1021. ['ds_r1_k3_s1_e1_c24'], # relu
  1022. # stage 1, 112x112 in
  1023. ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu
  1024. # stage 2, 56x56 in
  1025. ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish
  1026. # stage 3, 28x28 in
  1027. ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish
  1028. # stage 4, 14x14in
  1029. ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish
  1030. # stage 5, 14x14in
  1031. ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish
  1032. # 7x7
  1033. ]
  1034. model_kwargs = dict(
  1035. block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
  1036. num_features=1536,
  1037. stem_size=24,
  1038. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  1039. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  1040. **kwargs
  1041. )
  1042. model = _create_effnet(variant, pretrained, **model_kwargs)
  1043. return model
  1044. def _gen_tinynet(variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
  1045. """Creates a TinyNet model.
  1046. """
  1047. arch_def = [
  1048. ['ds_r1_k3_s1_e1_c16_se0.25'], ['ir_r2_k3_s2_e6_c24_se0.25'],
  1049. ['ir_r2_k5_s2_e6_c40_se0.25'], ['ir_r3_k3_s2_e6_c80_se0.25'],
  1050. ['ir_r3_k5_s1_e6_c112_se0.25'], ['ir_r4_k5_s2_e6_c192_se0.25'],
  1051. ['ir_r1_k3_s1_e6_c320_se0.25'],
  1052. ]
  1053. model_kwargs = dict(
  1054. block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
  1055. num_features=max(1280, round_channels(1280, model_width, 8, None)),
  1056. stem_size=32,
  1057. fix_stem=True,
  1058. round_chs_fn=partial(round_channels, multiplier=model_width),
  1059. act_layer=resolve_act_layer(kwargs, 'swish'),
  1060. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  1061. **kwargs,
  1062. )
  1063. model = _create_effnet(variant, pretrained, **model_kwargs)
  1064. return model
  1065. def _gen_mobilenet_edgetpu(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
  1066. """
  1067. Based on definitions in: https://github.com/tensorflow/models/tree/d2427a562f401c9af118e47af2f030a0a5599f55/official/projects/edgetpu/vision
  1068. """
  1069. if 'edgetpu_v2' in variant:
  1070. stem_size = 64
  1071. stem_kernel_size = 5
  1072. group_size = 64
  1073. num_features = 1280
  1074. act_layer = resolve_act_layer(kwargs, 'relu')
  1075. def _arch_def(chs: List[int], group_size: int):
  1076. return [
  1077. # stage 0, 112x112 in
  1078. [f'cn_r1_k1_s1_c{chs[0]}'], # NOTE with expansion==1, official impl block ends just 1x1 pwl
  1079. # stage 1, 112x112 in
  1080. [f'er_r1_k3_s2_e8_c{chs[1]}', f'er_r1_k3_s1_e4_gs{group_size}_c{chs[1]}'],
  1081. # stage 2, 56x56 in
  1082. [
  1083. f'er_r1_k3_s2_e8_c{chs[2]}',
  1084. f'er_r1_k3_s1_e4_gs{group_size}_c{chs[2]}',
  1085. f'er_r1_k3_s1_e4_c{chs[2]}',
  1086. f'er_r1_k3_s1_e4_gs{group_size}_c{chs[2]}',
  1087. ],
  1088. # stage 3, 28x28 in
  1089. [f'er_r1_k3_s2_e8_c{chs[3]}', f'ir_r3_k3_s1_e4_c{chs[3]}'],
  1090. # stage 4, 14x14in
  1091. [f'ir_r1_k3_s1_e8_c{chs[4]}', f'ir_r3_k3_s1_e4_c{chs[4]}'],
  1092. # stage 5, 14x14in
  1093. [f'ir_r1_k3_s2_e8_c{chs[5]}', f'ir_r3_k3_s1_e4_c{chs[5]}'],
  1094. # stage 6, 7x7 in
  1095. [f'ir_r1_k3_s1_e8_c{chs[6]}'],
  1096. ]
  1097. if 'edgetpu_v2_xs' in variant:
  1098. stem_size = 32
  1099. stem_kernel_size = 3
  1100. channels = [16, 32, 48, 96, 144, 160, 192]
  1101. elif 'edgetpu_v2_s' in variant:
  1102. channels = [24, 48, 64, 128, 160, 192, 256]
  1103. elif 'edgetpu_v2_m' in variant:
  1104. channels = [32, 64, 80, 160, 192, 240, 320]
  1105. num_features = 1344
  1106. elif 'edgetpu_v2_l' in variant:
  1107. stem_kernel_size = 7
  1108. group_size = 128
  1109. channels = [32, 64, 96, 192, 240, 256, 384]
  1110. num_features = 1408
  1111. else:
  1112. assert False
  1113. arch_def = _arch_def(channels, group_size)
  1114. else:
  1115. # v1
  1116. stem_size = 32
  1117. stem_kernel_size = 3
  1118. num_features = 1280
  1119. act_layer = resolve_act_layer(kwargs, 'relu')
  1120. arch_def = [
  1121. # stage 0, 112x112 in
  1122. ['cn_r1_k1_s1_c16'],
  1123. # stage 1, 112x112 in
  1124. ['er_r1_k3_s2_e8_c32', 'er_r3_k3_s1_e4_c32'],
  1125. # stage 2, 56x56 in
  1126. ['er_r1_k3_s2_e8_c48', 'er_r3_k3_s1_e4_c48'],
  1127. # stage 3, 28x28 in
  1128. ['ir_r1_k3_s2_e8_c96', 'ir_r3_k3_s1_e4_c96'],
  1129. # stage 4, 14x14in
  1130. ['ir_r1_k3_s1_e8_c96_noskip', 'ir_r3_k3_s1_e4_c96'],
  1131. # stage 5, 14x14in
  1132. ['ir_r1_k5_s2_e8_c160', 'ir_r3_k5_s1_e4_c160'],
  1133. # stage 6, 7x7 in
  1134. ['ir_r1_k3_s1_e8_c192'],
  1135. ]
  1136. model_kwargs = dict(
  1137. block_args=decode_arch_def(arch_def, depth_multiplier),
  1138. num_features=num_features,
  1139. stem_size=stem_size,
  1140. stem_kernel_size=stem_kernel_size,
  1141. round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
  1142. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  1143. act_layer=act_layer,
  1144. **kwargs,
  1145. )
  1146. model = _create_effnet(variant, pretrained, **model_kwargs)
  1147. return model
  1148. def _gen_test_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
  1149. """ Minimal test EfficientNet generator.
  1150. """
  1151. arch_def = [
  1152. ['cn_r1_k3_s1_e1_c16_skip'],
  1153. ['er_r1_k3_s2_e4_c24'],
  1154. ['er_r1_k3_s2_e4_c32'],
  1155. ['ir_r1_k3_s2_e4_c48_se0.25'],
  1156. ['ir_r1_k3_s2_e4_c64_se0.25'],
  1157. ]
  1158. round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
  1159. model_kwargs = dict(
  1160. block_args=decode_arch_def(arch_def, depth_multiplier),
  1161. num_features=round_chs_fn(256),
  1162. stem_size=24,
  1163. round_chs_fn=round_chs_fn,
  1164. norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
  1165. act_layer=resolve_act_layer(kwargs, 'silu'),
  1166. **kwargs,
  1167. )
  1168. model = _create_effnet(variant, pretrained, **model_kwargs)
  1169. return model
  1170. def _cfg(url='', **kwargs):
  1171. return {
  1172. 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  1173. 'crop_pct': 0.875, 'interpolation': 'bicubic',
  1174. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  1175. 'first_conv': 'conv_stem', 'classifier': 'classifier',
  1176. 'license': 'apache-2.0', **kwargs
  1177. }
  1178. default_cfgs = generate_default_cfgs({
  1179. 'mnasnet_050.untrained': _cfg(),
  1180. 'mnasnet_075.untrained': _cfg(),
  1181. 'mnasnet_100.rmsp_in1k': _cfg(
  1182. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth',
  1183. hf_hub_id='timm/'),
  1184. 'mnasnet_140.untrained': _cfg(),
  1185. 'semnasnet_050.untrained': _cfg(),
  1186. 'semnasnet_075.rmsp_in1k': _cfg(
  1187. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/semnasnet_075-18710866.pth',
  1188. hf_hub_id='timm/'),
  1189. 'semnasnet_100.rmsp_in1k': _cfg(
  1190. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth',
  1191. hf_hub_id='timm/'),
  1192. 'semnasnet_140.untrained': _cfg(),
  1193. 'mnasnet_small.lamb_in1k': _cfg(
  1194. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_small_lamb-aff75073.pth',
  1195. hf_hub_id='timm/'),
  1196. 'mobilenetv1_100.ra4_e3600_r224_in1k': _cfg(
  1197. hf_hub_id='timm/',
  1198. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1199. test_input_size=(3, 256, 256), test_crop_pct=0.95,
  1200. ),
  1201. 'mobilenetv1_100h.ra4_e3600_r224_in1k': _cfg(
  1202. hf_hub_id='timm/',
  1203. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1204. test_input_size=(3, 256, 256), test_crop_pct=0.95,
  1205. ),
  1206. 'mobilenetv1_125.ra4_e3600_r224_in1k': _cfg(
  1207. hf_hub_id='timm/',
  1208. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1209. crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0,
  1210. ),
  1211. 'mobilenetv2_035.untrained': _cfg(),
  1212. 'mobilenetv2_050.lamb_in1k': _cfg(
  1213. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_050-3d30d450.pth',
  1214. hf_hub_id='timm/',
  1215. interpolation='bicubic',
  1216. ),
  1217. 'mobilenetv2_075.untrained': _cfg(),
  1218. 'mobilenetv2_100.ra_in1k': _cfg(
  1219. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth',
  1220. hf_hub_id='timm/'),
  1221. 'mobilenetv2_110d.ra_in1k': _cfg(
  1222. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth',
  1223. hf_hub_id='timm/'),
  1224. 'mobilenetv2_120d.ra_in1k': _cfg(
  1225. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth',
  1226. hf_hub_id='timm/'),
  1227. 'mobilenetv2_140.ra_in1k': _cfg(
  1228. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth',
  1229. hf_hub_id='timm/'),
  1230. 'fbnetc_100.rmsp_in1k': _cfg(
  1231. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
  1232. hf_hub_id='timm/',
  1233. interpolation='bilinear'),
  1234. 'spnasnet_100.rmsp_in1k': _cfg(
  1235. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
  1236. hf_hub_id='timm/',
  1237. interpolation='bilinear'),
  1238. # NOTE experimenting with alternate attention
  1239. 'efficientnet_b0.ra_in1k': _cfg(
  1240. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth',
  1241. hf_hub_id='timm/'),
  1242. 'efficientnet_b0.ra4_e3600_r224_in1k': _cfg(
  1243. hf_hub_id='timm/',
  1244. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1245. crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=1.0),
  1246. 'efficientnet_b1.ra4_e3600_r240_in1k': _cfg(
  1247. hf_hub_id='timm/',
  1248. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1249. input_size=(3, 240, 240), crop_pct=0.9, pool_size=(8, 8),
  1250. test_input_size=(3, 288, 288), test_crop_pct=1.0),
  1251. 'efficientnet_b1.ft_in1k': _cfg(
  1252. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
  1253. hf_hub_id='timm/',
  1254. test_input_size=(3, 256, 256), test_crop_pct=1.0),
  1255. 'efficientnet_b2.ra_in1k': _cfg(
  1256. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth',
  1257. hf_hub_id='timm/',
  1258. input_size=(3, 256, 256), pool_size=(8, 8), test_input_size=(3, 288, 288), test_crop_pct=1.0),
  1259. 'efficientnet_b3.ra2_in1k': _cfg(
  1260. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth',
  1261. hf_hub_id='timm/',
  1262. input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), test_crop_pct=1.0),
  1263. 'efficientnet_b4.ra2_in1k': _cfg(
  1264. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth',
  1265. hf_hub_id='timm/',
  1266. input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), test_crop_pct=1.0),
  1267. 'efficientnet_b5.sw_in12k_ft_in1k': _cfg(
  1268. hf_hub_id='timm/',
  1269. input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0, crop_mode='squash'),
  1270. 'efficientnet_b5.sw_in12k': _cfg(
  1271. hf_hub_id='timm/',
  1272. input_size=(3, 416, 416), pool_size=(13, 13), crop_pct=0.95, num_classes=11821),
  1273. 'efficientnet_b6.untrained': _cfg(
  1274. url='', input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
  1275. 'efficientnet_b7.untrained': _cfg(
  1276. url='', input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
  1277. 'efficientnet_b8.untrained': _cfg(
  1278. url='', input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
  1279. 'efficientnet_l2.untrained': _cfg(
  1280. url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
  1281. # FIXME experimental
  1282. 'efficientnet_b0_gn.untrained': _cfg(),
  1283. 'efficientnet_b0_g8_gn.untrained': _cfg(),
  1284. 'efficientnet_b0_g16_evos.untrained': _cfg(),
  1285. 'efficientnet_b3_gn.untrained': _cfg(
  1286. input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
  1287. 'efficientnet_b3_g8_gn.untrained': _cfg(
  1288. input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
  1289. 'efficientnet_blur_b0.untrained': _cfg(),
  1290. 'efficientnet_h_b5.sw_r448_e450_in1k': _cfg(
  1291. hf_hub_id='timm/',
  1292. input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0,
  1293. crop_mode='squash', test_input_size=(3, 576, 576)),
  1294. 'efficientnet_x_b3.untrained': _cfg(
  1295. url='', input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=0.95),
  1296. 'efficientnet_x_b5.sw_r448_e450_in1k': _cfg(
  1297. hf_hub_id='timm/',
  1298. input_size=(3, 448, 448), pool_size=(14, 14), crop_pct=1.0,
  1299. crop_mode='squash', test_input_size=(3, 576, 576)),
  1300. 'efficientnet_es.ra_in1k': _cfg(
  1301. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth',
  1302. hf_hub_id='timm/'),
  1303. 'efficientnet_em.ra2_in1k': _cfg(
  1304. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_em_ra2-66250f76.pth',
  1305. hf_hub_id='timm/',
  1306. input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
  1307. 'efficientnet_el.ra_in1k': _cfg(
  1308. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_el-3b455510.pth',
  1309. hf_hub_id='timm/',
  1310. input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
  1311. 'efficientnet_es_pruned.in1k': _cfg(
  1312. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_pruned75-1b7248cf.pth',
  1313. hf_hub_id='timm/'),
  1314. 'efficientnet_el_pruned.in1k': _cfg(
  1315. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_el_pruned70-ef2a2ccf.pth',
  1316. hf_hub_id='timm/',
  1317. input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
  1318. 'efficientnet_cc_b0_4e.untrained': _cfg(),
  1319. 'efficientnet_cc_b0_8e.untrained': _cfg(),
  1320. 'efficientnet_cc_b1_8e.untrained': _cfg(input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
  1321. 'efficientnet_lite0.ra_in1k': _cfg(
  1322. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth',
  1323. hf_hub_id='timm/'),
  1324. 'efficientnet_lite1.untrained': _cfg(
  1325. input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
  1326. 'efficientnet_lite2.untrained': _cfg(
  1327. input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
  1328. 'efficientnet_lite3.untrained': _cfg(
  1329. input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
  1330. 'efficientnet_lite4.untrained': _cfg(
  1331. input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
  1332. 'efficientnet_b1_pruned.in1k': _cfg(
  1333. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb1_pruned-bea43a3a.pth',
  1334. hf_hub_id='timm/',
  1335. input_size=(3, 240, 240), pool_size=(8, 8),
  1336. crop_pct=0.882, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
  1337. 'efficientnet_b2_pruned.in1k': _cfg(
  1338. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb2_pruned-08c1b27c.pth',
  1339. hf_hub_id='timm/',
  1340. input_size=(3, 260, 260), pool_size=(9, 9),
  1341. crop_pct=0.890, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
  1342. 'efficientnet_b3_pruned.in1k': _cfg(
  1343. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/effnetb3_pruned-59ecf72d.pth',
  1344. hf_hub_id='timm/',
  1345. input_size=(3, 300, 300), pool_size=(10, 10),
  1346. crop_pct=0.904, mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
  1347. 'efficientnetv2_rw_t.ra2_in1k': _cfg(
  1348. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_t_agc-3620981a.pth',
  1349. hf_hub_id='timm/',
  1350. input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0),
  1351. 'gc_efficientnetv2_rw_t.agc_in1k': _cfg(
  1352. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gc_efficientnetv2_rw_t_agc-927a0bde.pth',
  1353. hf_hub_id='timm/',
  1354. input_size=(3, 224, 224), test_input_size=(3, 288, 288), pool_size=(7, 7), crop_pct=1.0),
  1355. 'efficientnetv2_rw_s.ra2_in1k': _cfg(
  1356. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth',
  1357. hf_hub_id='timm/',
  1358. input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
  1359. 'efficientnetv2_rw_m.agc_in1k': _cfg(
  1360. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_rw_m_agc-3d90cb1e.pth',
  1361. hf_hub_id='timm/',
  1362. input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
  1363. 'efficientnetv2_s.untrained': _cfg(
  1364. input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
  1365. 'efficientnetv2_m.untrained': _cfg(
  1366. input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
  1367. 'efficientnetv2_l.untrained': _cfg(
  1368. input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
  1369. 'efficientnetv2_xl.untrained': _cfg(
  1370. input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0),
  1371. 'tf_efficientnet_b0.ns_jft_in1k': _cfg(
  1372. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth',
  1373. hf_hub_id='timm/',
  1374. input_size=(3, 224, 224)),
  1375. 'tf_efficientnet_b1.ns_jft_in1k': _cfg(
  1376. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth',
  1377. hf_hub_id='timm/',
  1378. input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
  1379. 'tf_efficientnet_b2.ns_jft_in1k': _cfg(
  1380. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth',
  1381. hf_hub_id='timm/',
  1382. input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
  1383. 'tf_efficientnet_b3.ns_jft_in1k': _cfg(
  1384. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth',
  1385. hf_hub_id='timm/',
  1386. input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
  1387. 'tf_efficientnet_b4.ns_jft_in1k': _cfg(
  1388. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth',
  1389. hf_hub_id='timm/',
  1390. input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
  1391. 'tf_efficientnet_b5.ns_jft_in1k': _cfg(
  1392. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth',
  1393. hf_hub_id='timm/',
  1394. input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
  1395. 'tf_efficientnet_b6.ns_jft_in1k': _cfg(
  1396. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth',
  1397. hf_hub_id='timm/',
  1398. input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
  1399. 'tf_efficientnet_b7.ns_jft_in1k': _cfg(
  1400. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth',
  1401. hf_hub_id='timm/',
  1402. input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
  1403. 'tf_efficientnet_l2.ns_jft_in1k_475': _cfg(
  1404. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth',
  1405. hf_hub_id='timm/',
  1406. input_size=(3, 475, 475), pool_size=(15, 15), crop_pct=0.936),
  1407. 'tf_efficientnet_l2.ns_jft_in1k': _cfg(
  1408. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth',
  1409. hf_hub_id='timm/',
  1410. input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.96),
  1411. 'tf_efficientnet_b0.ap_in1k': _cfg(
  1412. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth',
  1413. hf_hub_id='timm/',
  1414. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, input_size=(3, 224, 224)),
  1415. 'tf_efficientnet_b1.ap_in1k': _cfg(
  1416. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth',
  1417. hf_hub_id='timm/',
  1418. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1419. input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
  1420. 'tf_efficientnet_b2.ap_in1k': _cfg(
  1421. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth',
  1422. hf_hub_id='timm/',
  1423. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1424. input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
  1425. 'tf_efficientnet_b3.ap_in1k': _cfg(
  1426. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth',
  1427. hf_hub_id='timm/',
  1428. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1429. input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
  1430. 'tf_efficientnet_b4.ap_in1k': _cfg(
  1431. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth',
  1432. hf_hub_id='timm/',
  1433. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1434. input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
  1435. 'tf_efficientnet_b5.ap_in1k': _cfg(
  1436. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth',
  1437. hf_hub_id='timm/',
  1438. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1439. input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
  1440. 'tf_efficientnet_b6.ap_in1k': _cfg(
  1441. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth',
  1442. hf_hub_id='timm/',
  1443. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1444. input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
  1445. 'tf_efficientnet_b7.ap_in1k': _cfg(
  1446. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth',
  1447. hf_hub_id='timm/',
  1448. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1449. input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
  1450. 'tf_efficientnet_b8.ap_in1k': _cfg(
  1451. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth',
  1452. hf_hub_id='timm/',
  1453. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1454. input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
  1455. 'tf_efficientnet_b5.ra_in1k': _cfg(
  1456. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth',
  1457. hf_hub_id='timm/',
  1458. input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
  1459. 'tf_efficientnet_b7.ra_in1k': _cfg(
  1460. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth',
  1461. hf_hub_id='timm/',
  1462. input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
  1463. 'tf_efficientnet_b8.ra_in1k': _cfg(
  1464. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth',
  1465. hf_hub_id='timm/',
  1466. input_size=(3, 672, 672), pool_size=(21, 21), crop_pct=0.954),
  1467. 'tf_efficientnet_b0.aa_in1k': _cfg(
  1468. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
  1469. hf_hub_id='timm/',
  1470. input_size=(3, 224, 224)),
  1471. 'tf_efficientnet_b1.aa_in1k': _cfg(
  1472. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth',
  1473. hf_hub_id='timm/',
  1474. input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
  1475. 'tf_efficientnet_b2.aa_in1k': _cfg(
  1476. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth',
  1477. hf_hub_id='timm/',
  1478. input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
  1479. 'tf_efficientnet_b3.aa_in1k': _cfg(
  1480. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth',
  1481. hf_hub_id='timm/',
  1482. input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
  1483. 'tf_efficientnet_b4.aa_in1k': _cfg(
  1484. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth',
  1485. hf_hub_id='timm/',
  1486. input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
  1487. 'tf_efficientnet_b5.aa_in1k': _cfg(
  1488. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_aa-99018a74.pth',
  1489. hf_hub_id='timm/',
  1490. input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
  1491. 'tf_efficientnet_b6.aa_in1k': _cfg(
  1492. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth',
  1493. hf_hub_id='timm/',
  1494. input_size=(3, 528, 528), pool_size=(17, 17), crop_pct=0.942),
  1495. 'tf_efficientnet_b7.aa_in1k': _cfg(
  1496. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth',
  1497. hf_hub_id='timm/',
  1498. input_size=(3, 600, 600), pool_size=(19, 19), crop_pct=0.949),
  1499. 'tf_efficientnet_b0.in1k': _cfg(
  1500. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
  1501. hf_hub_id='timm/',
  1502. input_size=(3, 224, 224)),
  1503. 'tf_efficientnet_b1.in1k': _cfg(
  1504. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth',
  1505. hf_hub_id='timm/',
  1506. input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
  1507. 'tf_efficientnet_b2.in1k': _cfg(
  1508. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth',
  1509. hf_hub_id='timm/',
  1510. input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
  1511. 'tf_efficientnet_b3.in1k': _cfg(
  1512. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
  1513. hf_hub_id='timm/',
  1514. input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
  1515. 'tf_efficientnet_b4.in1k': _cfg(
  1516. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth',
  1517. hf_hub_id='timm/',
  1518. input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
  1519. 'tf_efficientnet_b5.in1k': _cfg(
  1520. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth',
  1521. hf_hub_id='timm/',
  1522. input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
  1523. 'tf_efficientnet_es.in1k': _cfg(
  1524. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth',
  1525. hf_hub_id='timm/',
  1526. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1527. input_size=(3, 224, 224), ),
  1528. 'tf_efficientnet_em.in1k': _cfg(
  1529. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth',
  1530. hf_hub_id='timm/',
  1531. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1532. input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
  1533. 'tf_efficientnet_el.in1k': _cfg(
  1534. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth',
  1535. hf_hub_id='timm/',
  1536. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1537. input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
  1538. 'tf_efficientnet_cc_b0_4e.in1k': _cfg(
  1539. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth',
  1540. hf_hub_id='timm/',
  1541. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
  1542. 'tf_efficientnet_cc_b0_8e.in1k': _cfg(
  1543. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth',
  1544. hf_hub_id='timm/',
  1545. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
  1546. 'tf_efficientnet_cc_b1_8e.in1k': _cfg(
  1547. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth',
  1548. hf_hub_id='timm/',
  1549. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1550. input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
  1551. 'tf_efficientnet_lite0.in1k': _cfg(
  1552. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth',
  1553. hf_hub_id='timm/',
  1554. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1555. interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
  1556. ),
  1557. 'tf_efficientnet_lite1.in1k': _cfg(
  1558. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth',
  1559. hf_hub_id='timm/',
  1560. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1561. input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882,
  1562. interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
  1563. ),
  1564. 'tf_efficientnet_lite2.in1k': _cfg(
  1565. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth',
  1566. hf_hub_id='timm/',
  1567. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1568. input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890,
  1569. interpolation='bicubic', # should be bilinear but bicubic better match for TF bilinear at low res
  1570. ),
  1571. 'tf_efficientnet_lite3.in1k': _cfg(
  1572. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth',
  1573. hf_hub_id='timm/',
  1574. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1575. input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904, interpolation='bilinear'),
  1576. 'tf_efficientnet_lite4.in1k': _cfg(
  1577. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth',
  1578. hf_hub_id='timm/',
  1579. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1580. input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.920, interpolation='bilinear'),
  1581. 'tf_efficientnetv2_s.in21k_ft_in1k': _cfg(
  1582. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth',
  1583. hf_hub_id='timm/',
  1584. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1585. input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
  1586. 'tf_efficientnetv2_m.in21k_ft_in1k': _cfg(
  1587. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth',
  1588. hf_hub_id='timm/',
  1589. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1590. input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  1591. 'tf_efficientnetv2_l.in21k_ft_in1k': _cfg(
  1592. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth',
  1593. hf_hub_id='timm/',
  1594. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1595. input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  1596. 'tf_efficientnetv2_xl.in21k_ft_in1k': _cfg(
  1597. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21ft1k-06c35c48.pth',
  1598. hf_hub_id='timm/',
  1599. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1600. input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  1601. 'tf_efficientnetv2_s.in1k': _cfg(
  1602. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s-eb54923e.pth',
  1603. hf_hub_id='timm/',
  1604. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1605. input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
  1606. 'tf_efficientnetv2_m.in1k': _cfg(
  1607. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth',
  1608. hf_hub_id='timm/',
  1609. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1610. input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  1611. 'tf_efficientnetv2_l.in1k': _cfg(
  1612. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth',
  1613. hf_hub_id='timm/',
  1614. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1615. input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  1616. 'tf_efficientnetv2_s.in21k': _cfg(
  1617. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth',
  1618. hf_hub_id='timm/',
  1619. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
  1620. input_size=(3, 300, 300), test_input_size=(3, 384, 384), pool_size=(10, 10), crop_pct=1.0),
  1621. 'tf_efficientnetv2_m.in21k': _cfg(
  1622. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth',
  1623. hf_hub_id='timm/',
  1624. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
  1625. input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  1626. 'tf_efficientnetv2_l.in21k': _cfg(
  1627. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth',
  1628. hf_hub_id='timm/',
  1629. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
  1630. input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  1631. 'tf_efficientnetv2_xl.in21k': _cfg(
  1632. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21k-fd7e8abf.pth',
  1633. hf_hub_id='timm/',
  1634. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843,
  1635. input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
  1636. 'tf_efficientnetv2_b0.in1k': _cfg(
  1637. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth',
  1638. hf_hub_id='timm/',
  1639. input_size=(3, 192, 192), test_input_size=(3, 224, 224), pool_size=(6, 6)),
  1640. 'tf_efficientnetv2_b1.in1k': _cfg(
  1641. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b1-be6e41b0.pth',
  1642. hf_hub_id='timm/',
  1643. input_size=(3, 192, 192), test_input_size=(3, 240, 240), pool_size=(6, 6), crop_pct=0.882),
  1644. 'tf_efficientnetv2_b2.in1k': _cfg(
  1645. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b2-847de54e.pth',
  1646. hf_hub_id='timm/',
  1647. input_size=(3, 208, 208), test_input_size=(3, 260, 260), pool_size=(7, 7), crop_pct=0.890),
  1648. 'tf_efficientnetv2_b3.in21k_ft_in1k': _cfg(
  1649. hf_hub_id='timm/',
  1650. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1651. input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.9, crop_mode='squash'),
  1652. 'tf_efficientnetv2_b3.in1k': _cfg(
  1653. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b3-57773f13.pth',
  1654. hf_hub_id='timm/',
  1655. input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904),
  1656. 'tf_efficientnetv2_b3.in21k': _cfg(
  1657. hf_hub_id='timm/',
  1658. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, num_classes=21843,
  1659. input_size=(3, 240, 240), test_input_size=(3, 300, 300), pool_size=(8, 8), crop_pct=0.904),
  1660. 'mixnet_s.ft_in1k': _cfg(
  1661. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth',
  1662. hf_hub_id='timm/'),
  1663. 'mixnet_m.ft_in1k': _cfg(
  1664. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth',
  1665. hf_hub_id='timm/'),
  1666. 'mixnet_l.ft_in1k': _cfg(
  1667. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth',
  1668. hf_hub_id='timm/'),
  1669. 'mixnet_xl.ra_in1k': _cfg(
  1670. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth',
  1671. hf_hub_id='timm/'),
  1672. 'mixnet_xxl.untrained': _cfg(),
  1673. 'tf_mixnet_s.in1k': _cfg(
  1674. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth',
  1675. hf_hub_id='timm/'),
  1676. 'tf_mixnet_m.in1k': _cfg(
  1677. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth',
  1678. hf_hub_id='timm/'),
  1679. 'tf_mixnet_l.in1k': _cfg(
  1680. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth',
  1681. hf_hub_id='timm/'),
  1682. "tinynet_a.in1k": _cfg(
  1683. input_size=(3, 192, 192), pool_size=(6, 6), # int(224 * 0.86)
  1684. url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_a.pth',
  1685. hf_hub_id='timm/'),
  1686. "tinynet_b.in1k": _cfg(
  1687. input_size=(3, 188, 188), pool_size=(6, 6), # int(224 * 0.84)
  1688. url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_b.pth',
  1689. hf_hub_id='timm/'),
  1690. "tinynet_c.in1k": _cfg(
  1691. input_size=(3, 184, 184), pool_size=(6, 6), # int(224 * 0.825)
  1692. url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_c.pth',
  1693. hf_hub_id='timm/'),
  1694. "tinynet_d.in1k": _cfg(
  1695. input_size=(3, 152, 152), pool_size=(5, 5), # int(224 * 0.68)
  1696. url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_d.pth',
  1697. hf_hub_id='timm/'),
  1698. "tinynet_e.in1k": _cfg(
  1699. input_size=(3, 106, 106), pool_size=(4, 4), # int(224 * 0.475)
  1700. url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_e.pth',
  1701. hf_hub_id='timm/'),
  1702. 'mobilenet_edgetpu_100.untrained': _cfg(
  1703. # hf_hub_id='timm/',
  1704. input_size=(3, 224, 224), crop_pct=0.9),
  1705. 'mobilenet_edgetpu_v2_xs.untrained': _cfg(
  1706. # hf_hub_id='timm/',
  1707. input_size=(3, 224, 224), crop_pct=0.9),
  1708. 'mobilenet_edgetpu_v2_s.untrained': _cfg(
  1709. #hf_hub_id='timm/',
  1710. input_size=(3, 224, 224), crop_pct=0.9),
  1711. 'mobilenet_edgetpu_v2_m.ra4_e3600_r224_in1k': _cfg(
  1712. hf_hub_id='timm/',
  1713. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1714. crop_pct=0.9, test_input_size=(3, 256, 256), test_crop_pct=0.95,
  1715. ),
  1716. 'mobilenet_edgetpu_v2_l.untrained': _cfg(
  1717. #hf_hub_id='timm/',
  1718. input_size=(3, 224, 224), crop_pct=0.9),
  1719. "test_efficientnet.r160_in1k": _cfg(
  1720. hf_hub_id='timm/',
  1721. input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
  1722. "test_efficientnet_ln.r160_in1k": _cfg(
  1723. hf_hub_id='timm/',
  1724. input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
  1725. "test_efficientnet_gn.r160_in1k": _cfg(
  1726. hf_hub_id='timm/',
  1727. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1728. input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
  1729. "test_efficientnet_evos.r160_in1k": _cfg(
  1730. hf_hub_id='timm/',
  1731. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  1732. input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.95),
  1733. })
  @register_model
  def mnasnet_050(pretrained=False, **kwargs) -> EfficientNet:
      """ MNASNet-B1 built with a depth multiplier of 0.5. """
      return _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs)
  @register_model
  def mnasnet_075(pretrained=False, **kwargs) -> EfficientNet:
      """ MNASNet-B1 built with a depth multiplier of 0.75. """
      return _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs)
  @register_model
  def mnasnet_100(pretrained=False, **kwargs) -> EfficientNet:
      """ MNASNet-B1 built with a depth multiplier of 1.0. """
      return _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs)
  @register_model
  def mnasnet_140(pretrained=False, **kwargs) -> EfficientNet:
      """ MNASNet-B1 built with a depth multiplier of 1.4. """
      return _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs)
  @register_model
  def semnasnet_050(pretrained=False, **kwargs) -> EfficientNet:
      """ MNASNet-A1 with squeeze-excite (SE), depth multiplier of 0.5. """
      return _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs)
  @register_model
  def semnasnet_075(pretrained=False, **kwargs) -> EfficientNet:
      """ MNASNet-A1 with squeeze-excite (SE), depth multiplier of 0.75. """
      return _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs)
  @register_model
  def semnasnet_100(pretrained=False, **kwargs) -> EfficientNet:
      """ MNASNet-A1 with squeeze-excite (SE), depth multiplier of 1.0. """
      return _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs)
  @register_model
  def semnasnet_140(pretrained=False, **kwargs) -> EfficientNet:
      """ MNASNet-A1 with squeeze-excite (SE), depth multiplier of 1.4. """
      return _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs)
  @register_model
  def mnasnet_small(pretrained=False, **kwargs) -> EfficientNet:
      """ Small MNASNet variant, depth multiplier of 1.0. """
      return _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs)
  @register_model
  def mobilenetv1_100(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V1 with a 1.0 multiplier. """
      return _gen_mobilenet_v1('mobilenetv1_100', 1.0, pretrained=pretrained, **kwargs)
  @register_model
  def mobilenetv1_100h(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V1 with a 1.0 multiplier, built with ``head_conv=True``. """
      return _gen_mobilenet_v1(
          'mobilenetv1_100h', 1.0, head_conv=True, pretrained=pretrained, **kwargs)
  @register_model
  def mobilenetv1_125(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V1 with a 1.25 multiplier. """
      return _gen_mobilenet_v1('mobilenetv1_125', 1.25, pretrained=pretrained, **kwargs)
  @register_model
  def mobilenetv2_035(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V2 using a 0.35 channel multiplier. """
      return _gen_mobilenet_v2('mobilenetv2_035', 0.35, pretrained=pretrained, **kwargs)
  @register_model
  def mobilenetv2_050(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V2 using a 0.5 channel multiplier. """
      return _gen_mobilenet_v2('mobilenetv2_050', 0.5, pretrained=pretrained, **kwargs)
  @register_model
  def mobilenetv2_075(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V2 using a 0.75 channel multiplier. """
      return _gen_mobilenet_v2('mobilenetv2_075', 0.75, pretrained=pretrained, **kwargs)
  @register_model
  def mobilenetv2_100(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V2 using a 1.0 channel multiplier. """
      return _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs)
  @register_model
  def mobilenetv2_140(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V2 using a 1.4 channel multiplier. """
      return _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs)
  @register_model
  def mobilenetv2_110d(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V2 scaled by 1.1 (channels) and 1.2 (depth), with ``fix_stem_head=True``. """
      return _gen_mobilenet_v2(
          'mobilenetv2_110d',
          1.1,
          depth_multiplier=1.2,
          fix_stem_head=True,
          pretrained=pretrained,
          **kwargs)
  @register_model
  def mobilenetv2_120d(pretrained=False, **kwargs) -> EfficientNet:
      """ MobileNet-V2 scaled by 1.2 (channels) and 1.4 (depth), with ``fix_stem_head=True``. """
      return _gen_mobilenet_v2(
          'mobilenetv2_120d',
          1.2,
          depth_multiplier=1.4,
          fix_stem_head=True,
          pretrained=pretrained,
          **kwargs)
  @register_model
  def fbnetc_100(pretrained=False, **kwargs) -> EfficientNet:
      """ FBNet-C.

      When ``pretrained`` is set, ``bn_eps`` defaults to ``BN_EPS_TF_DEFAULT`` (unless the
      caller passes one), because the released weights were trained with a non-default
      BatchNorm epsilon.
      """
      if pretrained:
          kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
      return _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs)
  @register_model
  def spnasnet_100(pretrained=False, **kwargs) -> EfficientNet:
      """ Single-Path NAS, Pixel1 variant (multiplier 1.0). """
      return _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b0(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B0 (channel x1.0, depth x1.0). """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b1(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B1 (channel x1.0, depth x1.1). """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b2(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B2 (channel x1.1, depth x1.2). """
      # Training recipe: drop_rate=0.3, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b3(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B3 (channel x1.2, depth x1.4). """
      # Training recipe: drop_rate=0.3, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b4(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B4 (channel x1.4, depth x1.8). """
      # Training recipe: drop_rate=0.4, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b5(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B5 (channel x1.6, depth x2.2). """
      # Training recipe: drop_rate=0.4, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b6(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B6 (channel x1.8, depth x2.6). """
      # Training recipe: drop_rate=0.5, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b7(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B7 (channel x2.0, depth x3.1). """
      # Training recipe: drop_rate=0.5, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b8(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B8 (channel x2.2, depth x3.6). """
      # Training recipe: drop_rate=0.5, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_l2(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-L2 (channel x4.3, depth x5.3). """
      # Training recipe: drop_rate=0.5, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3,
          pretrained=pretrained, **kwargs)
  # FIXME experimental group conv / GroupNorm / EvoNorm experiments
  @register_model
  def efficientnet_b0_gn(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B0 with GroupNormAct (group_size=8) as the norm layer. """
      gn_layer = partial(GroupNormAct, group_size=8)
      return _gen_efficientnet(
          'efficientnet_b0_gn', norm_layer=gn_layer, pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b0_g8_gn(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B0 with grouped convs (group_size=8) and GroupNormAct (group_size=8). """
      gn_layer = partial(GroupNormAct, group_size=8)
      return _gen_efficientnet(
          'efficientnet_b0_g8_gn', group_size=8, norm_layer=gn_layer,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b0_g16_evos(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B0 w/ group 16 conv (EvoNorm intended, currently disabled).

      NOTE: despite the ``_evos`` suffix, no EvoNorm layer is configured here. The
      ``norm_layer=partial(EvoNorm2dS0, group_size=16)`` argument is commented out, so the
      model is built with whatever norm layer ``_gen_efficientnet`` uses by default, or
      one supplied by the caller through ``kwargs``.
      """
      # Disabled: norm_layer=partial(EvoNorm2dS0, group_size=16),
      model = _gen_efficientnet(
          'efficientnet_b0_g16_evos', group_size=16, channel_divisor=16,
          pretrained=pretrained, **kwargs)
      return model
  @register_model
  def efficientnet_b3_gn(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B3 with GroupNormAct (group_size=16) and channel_divisor=16. """
      # Training recipe: drop_rate=0.3, drop_path_rate=0.2.
      gn_layer = partial(GroupNormAct, group_size=16)
      return _gen_efficientnet(
          'efficientnet_b3_gn', channel_multiplier=1.2, depth_multiplier=1.4, channel_divisor=16,
          norm_layer=gn_layer, pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b3_g8_gn(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B3 w/ grouped conv (group_size=8) + GroupNorm (GroupNormAct, group_size=16).

      Channel counts are kept divisible by 16 (``channel_divisor=16``).
      """
      # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
      model = _gen_efficientnet(
          'efficientnet_b3_g8_gn', channel_multiplier=1.2, depth_multiplier=1.4, group_size=8, channel_divisor=16,
          norm_layer=partial(GroupNormAct, group_size=16), pretrained=pretrained, **kwargs)
      return model
  @register_model
  def efficientnet_blur_b0(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B0 using the 'blurpc' anti-aliasing layer (BlurPool). """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      return _gen_efficientnet(
          'efficientnet_blur_b0', channel_multiplier=1.0, depth_multiplier=1.0,
          pretrained=pretrained, aa_layer='blurpc', **kwargs)
  @register_model
  def efficientnet_es(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-EdgeTPU Small (channel x1.0, depth x1.0). """
      return _gen_efficientnet_edge(
          'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_es_pruned(pretrained=False, **kwargs) -> EfficientNet:
      """ Pruned EfficientNet-Edge Small.

      For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0
      """
      return _gen_efficientnet_edge(
          'efficientnet_es_pruned', channel_multiplier=1.0, depth_multiplier=1.0,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_em(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Edge Medium (channel x1.0, depth x1.1). """
      return _gen_efficientnet_edge(
          'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_el(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Edge Large (channel x1.2, depth x1.4). """
      return _gen_efficientnet_edge(
          'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_el_pruned(pretrained=False, **kwargs) -> EfficientNet:
      """ Pruned EfficientNet-Edge Large.

      For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0
      """
      return _gen_efficientnet_edge(
          'efficientnet_el_pruned', channel_multiplier=1.2, depth_multiplier=1.4,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_cc_b0_4e(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-CondConv-B0 w/ 4 Experts.

      No ``experts_multiplier`` is passed, unlike the ``_8e`` variants which pass
      ``experts_multiplier=2``.
      """
      # NOTE for train, drop_rate should be 0.2, drop_path_rate should be 0.2
      model = _gen_efficientnet_condconv(
          'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
      return model
  @register_model
  def efficientnet_cc_b0_8e(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-CondConv-B0 with 8 experts (``experts_multiplier=2``). """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      return _gen_efficientnet_condconv(
          'efficientnet_cc_b0_8e',
          channel_multiplier=1.0,
          depth_multiplier=1.0,
          experts_multiplier=2,
          pretrained=pretrained,
          **kwargs)
  @register_model
  def efficientnet_cc_b1_8e(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-CondConv-B1 with 8 experts (``experts_multiplier=2``). """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      return _gen_efficientnet_condconv(
          'efficientnet_cc_b1_8e',
          channel_multiplier=1.0,
          depth_multiplier=1.1,
          experts_multiplier=2,
          pretrained=pretrained,
          **kwargs)
  @register_model
  def efficientnet_lite0(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite0 (channel x1.0, depth x1.0). """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      return _gen_efficientnet_lite(
          'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_lite1(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite1 (channel x1.0, depth x1.1). """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      return _gen_efficientnet_lite(
          'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_lite2(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite2 (channel x1.1, depth x1.2). """
      # Training recipe: drop_rate=0.3, drop_path_rate=0.2.
      return _gen_efficientnet_lite(
          'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_lite3(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite3 (channel x1.2, depth x1.4). """
      # Training recipe: drop_rate=0.3, drop_path_rate=0.2.
      return _gen_efficientnet_lite(
          'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_lite4(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite4 (channel x1.4, depth x1.8). """
      # Training recipe: drop_rate=0.4, drop_path_rate=0.2.
      return _gen_efficientnet_lite(
          'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b1_pruned(pretrained=False, **kwargs) -> EfficientNet:
      """ Pruned EfficientNet-B1 (pruning per https://arxiv.org/pdf/2002.08258.pdf).

      Uses TF-style defaults (``BN_EPS_TF_DEFAULT`` and 'same' padding) unless overridden.
      """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'efficientnet_b1_pruned', channel_multiplier=1.0, depth_multiplier=1.1, pruned=True,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b2_pruned(pretrained=False, **kwargs) -> EfficientNet:
      """ Pruned EfficientNet-B2 (pruning per https://arxiv.org/pdf/2002.08258.pdf).

      Uses TF-style defaults (``BN_EPS_TF_DEFAULT`` and 'same' padding) unless overridden.
      """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      variant = 'efficientnet_b2_pruned'
      return _gen_efficientnet(
          variant, channel_multiplier=1.1, depth_multiplier=1.2, pruned=True,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnet_b3_pruned(pretrained=False, **kwargs) -> EfficientNet:
      """ Pruned EfficientNet-B3 (pruning per https://arxiv.org/pdf/2002.08258.pdf).

      Uses TF-style defaults (``BN_EPS_TF_DEFAULT`` and 'same' padding) unless overridden.
      """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      variant = 'efficientnet_b3_pruned'
      return _gen_efficientnet(
          variant, channel_multiplier=1.2, depth_multiplier=1.4, pruned=True,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnetv2_rw_t(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-V2 Tiny; a custom variant (no Tiny size in the paper), channel x0.8, depth x0.9. """
      return _gen_efficientnetv2_s(
          'efficientnetv2_rw_t',
          channel_multiplier=0.8,
          depth_multiplier=0.9,
          rw=False,
          pretrained=pretrained,
          **kwargs)
  @register_model
  def gc_efficientnetv2_rw_t(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-V2 Tiny (custom, not in paper) using Global Context attention (``se_layer='gc'``). """
      return _gen_efficientnetv2_s(
          'gc_efficientnetv2_rw_t',
          channel_multiplier=0.8,
          depth_multiplier=0.9,
          rw=False,
          se_layer='gc',
          pretrained=pretrained,
          **kwargs)
  @register_model
  def efficientnetv2_rw_s(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-V2 Small, RW variant (built with ``rw=True``).

      NOTE: this is the author's initial version from before the official code release and
      has some differences. For versions matching the official model, see efficientnetv2_s
      (PyTorch padding) and tf_efficientnetv2_s (TF padding).
      """
      return _gen_efficientnetv2_s('efficientnetv2_rw_s', rw=True, pretrained=pretrained, **kwargs)
  @register_model
  def efficientnetv2_rw_m(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-V2 Medium, RW variant (built with ``rw=True``).

      Channel multiplier 1.2. The depth multiplier is a 6-element tuple
      (presumably one entry per stage): 1.2 for the first four, 1.6 for the last two.
      """
      stage_depths = (1.2, 1.2, 1.2, 1.2, 1.6, 1.6)
      return _gen_efficientnetv2_s(
          'efficientnetv2_rw_m', channel_multiplier=1.2, depth_multiplier=stage_depths, rw=True,
          pretrained=pretrained, **kwargs)
  @register_model
  def efficientnetv2_s(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-V2, Small size. """
      return _gen_efficientnetv2_s(
          'efficientnetv2_s', pretrained=pretrained, **kwargs)
  @register_model
  def efficientnetv2_m(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-V2, Medium size. """
      return _gen_efficientnetv2_m(
          'efficientnetv2_m', pretrained=pretrained, **kwargs)
  @register_model
  def efficientnetv2_l(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-V2, Large size. """
      return _gen_efficientnetv2_l(
          'efficientnetv2_l', pretrained=pretrained, **kwargs)
  @register_model
  def efficientnetv2_xl(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-V2, Extra-Large size. """
      return _gen_efficientnetv2_xl(
          'efficientnetv2_xl', pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_b0(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B0, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_b1(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B1, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_b2(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B2, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_b3(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B3, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_b4(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B4, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_b5(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B5, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_b6(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B6, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      # Training recipe: drop_rate=0.5.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_b7(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B7, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      # Training recipe: drop_rate=0.5.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_b8(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-B8, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      # Training recipe: drop_rate=0.5.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_l2(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-L2 NoisyStudent, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      # Training recipe: drop_rate=0.5.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet(
          'tf_efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_es(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Edge Small, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_edge(
          'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_em(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Edge Medium, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_edge(
          'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_el(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Edge Large, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_edge(
          'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-CondConv-B0 with 4 experts, TensorFlow-compatible variant. """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_condconv(
          'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-CondConv-B0 with 8 experts (``experts_multiplier=2``), TensorFlow-compatible variant. """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_condconv(
          'tf_efficientnet_cc_b0_8e',
          channel_multiplier=1.0,
          depth_multiplier=1.0,
          experts_multiplier=2,
          pretrained=pretrained,
          **kwargs)
  @register_model
  def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-CondConv-B1 with 8 experts (``experts_multiplier=2``), TensorFlow-compatible variant. """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_condconv(
          'tf_efficientnet_cc_b1_8e',
          channel_multiplier=1.0,
          depth_multiplier=1.1,
          experts_multiplier=2,
          pretrained=pretrained,
          **kwargs)
  @register_model
  def tf_efficientnet_lite0(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite0, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_lite(
          'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_lite1(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite1, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      # Training recipe: drop_rate=0.2, drop_path_rate=0.2.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_lite(
          'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_lite2(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite2, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      # Training recipe: drop_rate=0.3, drop_path_rate=0.2.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_lite(
          'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_lite3(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite3, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      # Training recipe: drop_rate=0.3, drop_path_rate=0.2.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_lite(
          'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnet_lite4(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-Lite4, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      # Training recipe: drop_rate=0.4, drop_path_rate=0.2.
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnet_lite(
          'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8,
          pretrained=pretrained, **kwargs)
  @register_model
  def tf_efficientnetv2_s(pretrained=False, **kwargs) -> EfficientNet:
      """ EfficientNet-V2 Small, TensorFlow-compatible variant (TF BN eps and 'same' padding by default). """
      kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)  # caller-supplied bn_eps wins
      kwargs.setdefault('pad_type', 'same')  # caller-supplied pad_type wins
      return _gen_efficientnetv2_s(
          'tf_efficientnetv2_s', pretrained=pretrained, **kwargs)
  2313. @register_model
  2314. def tf_efficientnetv2_m(pretrained=False, **kwargs) -> EfficientNet:
  2315. """ EfficientNet-V2 Medium. Tensorflow compatible variant """
  2316. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2317. kwargs.setdefault('pad_type', 'same')
  2318. model = _gen_efficientnetv2_m('tf_efficientnetv2_m', pretrained=pretrained, **kwargs)
  2319. return model
  2320. @register_model
  2321. def tf_efficientnetv2_l(pretrained=False, **kwargs) -> EfficientNet:
  2322. """ EfficientNet-V2 Large. Tensorflow compatible variant """
  2323. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2324. kwargs.setdefault('pad_type', 'same')
  2325. model = _gen_efficientnetv2_l('tf_efficientnetv2_l', pretrained=pretrained, **kwargs)
  2326. return model
  2327. @register_model
  2328. def tf_efficientnetv2_xl(pretrained=False, **kwargs) -> EfficientNet:
  2329. """ EfficientNet-V2 Xtra-Large. Tensorflow compatible variant
  2330. """
  2331. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2332. kwargs.setdefault('pad_type', 'same')
  2333. model = _gen_efficientnetv2_xl('tf_efficientnetv2_xl', pretrained=pretrained, **kwargs)
  2334. return model
  2335. @register_model
  2336. def tf_efficientnetv2_b0(pretrained=False, **kwargs) -> EfficientNet:
  2337. """ EfficientNet-V2-B0. Tensorflow compatible variant """
  2338. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2339. kwargs.setdefault('pad_type', 'same')
  2340. model = _gen_efficientnetv2_base('tf_efficientnetv2_b0', pretrained=pretrained, **kwargs)
  2341. return model
  2342. @register_model
  2343. def tf_efficientnetv2_b1(pretrained=False, **kwargs) -> EfficientNet:
  2344. """ EfficientNet-V2-B1. Tensorflow compatible variant """
  2345. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2346. kwargs.setdefault('pad_type', 'same')
  2347. model = _gen_efficientnetv2_base(
  2348. 'tf_efficientnetv2_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs)
  2349. return model
  2350. @register_model
  2351. def tf_efficientnetv2_b2(pretrained=False, **kwargs) -> EfficientNet:
  2352. """ EfficientNet-V2-B2. Tensorflow compatible variant """
  2353. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2354. kwargs.setdefault('pad_type', 'same')
  2355. model = _gen_efficientnetv2_base(
  2356. 'tf_efficientnetv2_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
  2357. return model
  2358. @register_model
  2359. def tf_efficientnetv2_b3(pretrained=False, **kwargs) -> EfficientNet:
  2360. """ EfficientNet-V2-B3. Tensorflow compatible variant """
  2361. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2362. kwargs.setdefault('pad_type', 'same')
  2363. model = _gen_efficientnetv2_base(
  2364. 'tf_efficientnetv2_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
  2365. return model
  2366. @register_model
  2367. def efficientnet_x_b3(pretrained=False, **kwargs) -> EfficientNet:
  2368. """ EfficientNet-B3 """
  2369. # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
  2370. model = _gen_efficientnet_x(
  2371. 'efficientnet_x_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs)
  2372. return model
  2373. @register_model
  2374. def efficientnet_x_b5(pretrained=False, **kwargs) -> EfficientNet:
  2375. """ EfficientNet-B5 """
  2376. model = _gen_efficientnet_x(
  2377. 'efficientnet_x_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs)
  2378. return model
  2379. @register_model
  2380. def efficientnet_h_b5(pretrained=False, **kwargs) -> EfficientNet:
  2381. """ EfficientNet-B5 """
  2382. model = _gen_efficientnet_x(
  2383. 'efficientnet_h_b5', channel_multiplier=1.92, depth_multiplier=2.2, version=2, pretrained=pretrained, **kwargs)
  2384. return model
  2385. @register_model
  2386. def mixnet_s(pretrained=False, **kwargs) -> EfficientNet:
  2387. """Creates a MixNet Small model.
  2388. """
  2389. model = _gen_mixnet_s(
  2390. 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
  2391. return model
  2392. @register_model
  2393. def mixnet_m(pretrained=False, **kwargs) -> EfficientNet:
  2394. """Creates a MixNet Medium model.
  2395. """
  2396. model = _gen_mixnet_m(
  2397. 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
  2398. return model
  2399. @register_model
  2400. def mixnet_l(pretrained=False, **kwargs) -> EfficientNet:
  2401. """Creates a MixNet Large model.
  2402. """
  2403. model = _gen_mixnet_m(
  2404. 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
  2405. return model
  2406. @register_model
  2407. def mixnet_xl(pretrained=False, **kwargs) -> EfficientNet:
  2408. """Creates a MixNet Extra-Large model.
  2409. Not a paper spec, experimental def by RW w/ depth scaling.
  2410. """
  2411. model = _gen_mixnet_m(
  2412. 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs)
  2413. return model
  2414. @register_model
  2415. def mixnet_xxl(pretrained=False, **kwargs) -> EfficientNet:
  2416. """Creates a MixNet Double Extra Large model.
  2417. Not a paper spec, experimental def by RW w/ depth scaling.
  2418. """
  2419. model = _gen_mixnet_m(
  2420. 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs)
  2421. return model
  2422. @register_model
  2423. def tf_mixnet_s(pretrained=False, **kwargs) -> EfficientNet:
  2424. """Creates a MixNet Small model. Tensorflow compatible variant
  2425. """
  2426. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2427. kwargs.setdefault('pad_type', 'same')
  2428. model = _gen_mixnet_s(
  2429. 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
  2430. return model
  2431. @register_model
  2432. def tf_mixnet_m(pretrained=False, **kwargs) -> EfficientNet:
  2433. """Creates a MixNet Medium model. Tensorflow compatible variant
  2434. """
  2435. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2436. kwargs.setdefault('pad_type', 'same')
  2437. model = _gen_mixnet_m(
  2438. 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs)
  2439. return model
  2440. @register_model
  2441. def tf_mixnet_l(pretrained=False, **kwargs) -> EfficientNet:
  2442. """Creates a MixNet Large model. Tensorflow compatible variant
  2443. """
  2444. kwargs.setdefault('bn_eps', BN_EPS_TF_DEFAULT)
  2445. kwargs.setdefault('pad_type', 'same')
  2446. model = _gen_mixnet_m(
  2447. 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
  2448. return model
  2449. @register_model
  2450. def tinynet_a(pretrained=False, **kwargs) -> EfficientNet:
  2451. model = _gen_tinynet('tinynet_a', 1.0, 1.2, pretrained=pretrained, **kwargs)
  2452. return model
  2453. @register_model
  2454. def tinynet_b(pretrained=False, **kwargs) -> EfficientNet:
  2455. model = _gen_tinynet('tinynet_b', 0.75, 1.1, pretrained=pretrained, **kwargs)
  2456. return model
  2457. @register_model
  2458. def tinynet_c(pretrained=False, **kwargs) -> EfficientNet:
  2459. model = _gen_tinynet('tinynet_c', 0.54, 0.85, pretrained=pretrained, **kwargs)
  2460. return model
  2461. @register_model
  2462. def tinynet_d(pretrained=False, **kwargs) -> EfficientNet:
  2463. model = _gen_tinynet('tinynet_d', 0.54, 0.695, pretrained=pretrained, **kwargs)
  2464. return model
  2465. @register_model
  2466. def tinynet_e(pretrained=False, **kwargs) -> EfficientNet:
  2467. model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs)
  2468. return model
  2469. @register_model
  2470. def mobilenet_edgetpu_100(pretrained=False, **kwargs) -> EfficientNet:
  2471. """ MobileNet-EdgeTPU-v1 100. """
  2472. model = _gen_mobilenet_edgetpu('mobilenet_edgetpu_100', pretrained=pretrained, **kwargs)
  2473. return model
  2474. @register_model
  2475. def mobilenet_edgetpu_v2_xs(pretrained=False, **kwargs) -> EfficientNet:
  2476. """ MobileNet-EdgeTPU-v2 Extra Small. """
  2477. model = _gen_mobilenet_edgetpu('mobilenet_edgetpu_v2_xs', pretrained=pretrained, **kwargs)
  2478. return model
  2479. @register_model
  2480. def mobilenet_edgetpu_v2_s(pretrained=False, **kwargs) -> EfficientNet:
  2481. """ MobileNet-EdgeTPU-v2 Small. """
  2482. model = _gen_mobilenet_edgetpu('mobilenet_edgetpu_v2_s', pretrained=pretrained, **kwargs)
  2483. return model
  2484. @register_model
  2485. def mobilenet_edgetpu_v2_m(pretrained=False, **kwargs) -> EfficientNet:
  2486. """ MobileNet-EdgeTPU-v2 Medium. """
  2487. model = _gen_mobilenet_edgetpu('mobilenet_edgetpu_v2_m', pretrained=pretrained, **kwargs)
  2488. return model
  2489. @register_model
  2490. def mobilenet_edgetpu_v2_l(pretrained=False, **kwargs) -> EfficientNet:
  2491. """ MobileNet-EdgeTPU-v2 Large. """
  2492. model = _gen_mobilenet_edgetpu('mobilenet_edgetpu_v2_l', pretrained=pretrained, **kwargs)
  2493. return model
  2494. @register_model
  2495. def test_efficientnet(pretrained=False, **kwargs) -> EfficientNet:
  2496. model = _gen_test_efficientnet('test_efficientnet', pretrained=pretrained, **kwargs)
  2497. return model
  2498. @register_model
  2499. def test_efficientnet_gn(pretrained=False, **kwargs) -> EfficientNet:
  2500. model = _gen_test_efficientnet(
  2501. 'test_efficientnet_gn',
  2502. pretrained=pretrained,
  2503. norm_layer=kwargs.pop('norm_layer', partial(GroupNormAct, group_size=8)),
  2504. **kwargs
  2505. )
  2506. return model
  2507. @register_model
  2508. def test_efficientnet_ln(pretrained=False, **kwargs) -> EfficientNet:
  2509. model = _gen_test_efficientnet(
  2510. 'test_efficientnet_ln',
  2511. pretrained=pretrained,
  2512. norm_layer=kwargs.pop('norm_layer', LayerNormAct2d),
  2513. **kwargs
  2514. )
  2515. return model
  2516. @register_model
  2517. def test_efficientnet_evos(pretrained=False, **kwargs) -> EfficientNet:
  2518. model = _gen_test_efficientnet(
  2519. 'test_efficientnet_evos',
  2520. pretrained=pretrained,
  2521. norm_layer=kwargs.pop('norm_layer', partial(EvoNorm2dS0, group_size=8)),
  2522. **kwargs
  2523. )
  2524. return model
  2525. register_model_deprecations(__name__, {
  2526. 'tf_efficientnet_b0_ap': 'tf_efficientnet_b0.ap_in1k',
  2527. 'tf_efficientnet_b1_ap': 'tf_efficientnet_b1.ap_in1k',
  2528. 'tf_efficientnet_b2_ap': 'tf_efficientnet_b2.ap_in1k',
  2529. 'tf_efficientnet_b3_ap': 'tf_efficientnet_b3.ap_in1k',
  2530. 'tf_efficientnet_b4_ap': 'tf_efficientnet_b4.ap_in1k',
  2531. 'tf_efficientnet_b5_ap': 'tf_efficientnet_b5.ap_in1k',
  2532. 'tf_efficientnet_b6_ap': 'tf_efficientnet_b6.ap_in1k',
  2533. 'tf_efficientnet_b7_ap': 'tf_efficientnet_b7.ap_in1k',
  2534. 'tf_efficientnet_b8_ap': 'tf_efficientnet_b8.ap_in1k',
  2535. 'tf_efficientnet_b0_ns': 'tf_efficientnet_b0.ns_jft_in1k',
  2536. 'tf_efficientnet_b1_ns': 'tf_efficientnet_b1.ns_jft_in1k',
  2537. 'tf_efficientnet_b2_ns': 'tf_efficientnet_b2.ns_jft_in1k',
  2538. 'tf_efficientnet_b3_ns': 'tf_efficientnet_b3.ns_jft_in1k',
  2539. 'tf_efficientnet_b4_ns': 'tf_efficientnet_b4.ns_jft_in1k',
  2540. 'tf_efficientnet_b5_ns': 'tf_efficientnet_b5.ns_jft_in1k',
  2541. 'tf_efficientnet_b6_ns': 'tf_efficientnet_b6.ns_jft_in1k',
  2542. 'tf_efficientnet_b7_ns': 'tf_efficientnet_b7.ns_jft_in1k',
  2543. 'tf_efficientnet_l2_ns_475': 'tf_efficientnet_l2.ns_jft_in1k_475',
  2544. 'tf_efficientnet_l2_ns': 'tf_efficientnet_l2.ns_jft_in1k',
  2545. 'tf_efficientnetv2_s_in21ft1k': 'tf_efficientnetv2_s.in21k_ft_in1k',
  2546. 'tf_efficientnetv2_m_in21ft1k': 'tf_efficientnetv2_m.in21k_ft_in1k',
  2547. 'tf_efficientnetv2_l_in21ft1k': 'tf_efficientnetv2_l.in21k_ft_in1k',
  2548. 'tf_efficientnetv2_xl_in21ft1k': 'tf_efficientnetv2_xl.in21k_ft_in1k',
  2549. 'tf_efficientnetv2_s_in21k': 'tf_efficientnetv2_s.in21k',
  2550. 'tf_efficientnetv2_m_in21k': 'tf_efficientnetv2_m.in21k',
  2551. 'tf_efficientnetv2_l_in21k': 'tf_efficientnetv2_l.in21k',
  2552. 'tf_efficientnetv2_xl_in21k': 'tf_efficientnetv2_xl.in21k',
  2553. 'efficientnet_b2a': 'efficientnet_b2',
  2554. 'efficientnet_b3a': 'efficientnet_b3',
  2555. 'mnasnet_a1': 'semnasnet_100',
  2556. 'mnasnet_b1': 'mnasnet_100',
  2557. })