resnet.py 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265
  1. """PyTorch ResNet
  2. This started as a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
  3. additional dropout and dynamic global avg/max pool.
  4. ResNeXt, SE-ResNeXt, SENet, and MXNet Gluon stem/downsample variants, tiered stems added by Ross Wightman
  5. Copyright 2019, Ross Wightman
  6. """
  7. import math
  8. from functools import partial
  9. from typing import Any, Dict, List, Optional, Tuple, Type, Union
  10. import torch
  11. import torch.nn as nn
  12. import torch.nn.functional as F
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, LayerType, create_attn, \
  15. get_attn, get_act_layer, get_norm_layer, create_classifier, create_aa, to_ntuple
  16. from ._builder import build_model_with_cfg
  17. from ._features import feature_take_indices
  18. from ._manipulate import checkpoint_seq
  19. from ._registry import register_model, generate_default_cfgs, register_model_deprecations
  20. __all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
  21. def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
  22. padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
  23. return padding
  24. class BasicBlock(nn.Module):
  25. """Basic residual block for ResNet.
  26. This is the standard residual block used in ResNet-18 and ResNet-34.
  27. """
  28. expansion = 1
  29. def __init__(
  30. self,
  31. inplanes: int,
  32. planes: int,
  33. stride: int = 1,
  34. downsample: Optional[nn.Module] = None,
  35. cardinality: int = 1,
  36. base_width: int = 64,
  37. reduce_first: int = 1,
  38. dilation: int = 1,
  39. first_dilation: Optional[int] = None,
  40. act_layer: Type[nn.Module] = nn.ReLU,
  41. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  42. attn_layer: Optional[Type[nn.Module]] = None,
  43. aa_layer: Optional[Type[nn.Module]] = None,
  44. drop_block: Optional[Type[nn.Module]] = None,
  45. drop_path: Optional[nn.Module] = None,
  46. device=None,
  47. dtype=None,
  48. ) -> None:
  49. """
  50. Args:
  51. inplanes: Input channel dimensionality.
  52. planes: Used to determine output channel dimensionalities.
  53. stride: Stride used in convolution layers.
  54. downsample: Optional downsample layer for residual path.
  55. cardinality: Number of convolution groups.
  56. base_width: Base width used to determine output channel dimensionality.
  57. reduce_first: Reduction factor for first convolution output width of residual blocks.
  58. dilation: Dilation rate for convolution layers.
  59. first_dilation: Dilation rate for first convolution layer.
  60. act_layer: Activation layer class.
  61. norm_layer: Normalization layer class.
  62. attn_layer: Attention layer class.
  63. aa_layer: Anti-aliasing layer class.
  64. drop_block: DropBlock layer class.
  65. drop_path: Optional DropPath layer instance.
  66. """
  67. dd = {'device': device, 'dtype': dtype}
  68. super().__init__()
  69. assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
  70. assert base_width == 64, 'BasicBlock does not support changing base width'
  71. first_planes = planes // reduce_first
  72. outplanes = planes * self.expansion
  73. first_dilation = first_dilation or dilation
  74. use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
  75. self.conv1 = nn.Conv2d(
  76. inplanes,
  77. first_planes,
  78. kernel_size=3,
  79. stride=1 if use_aa else stride,
  80. padding=first_dilation,
  81. dilation=first_dilation,
  82. bias=False,
  83. **dd,
  84. )
  85. self.bn1 = norm_layer(first_planes, **dd)
  86. self.drop_block = drop_block() if drop_block is not None else nn.Identity()
  87. self.act1 = act_layer(inplace=True)
  88. self.aa = create_aa(aa_layer, channels=first_planes, stride=stride, enable=use_aa, **dd)
  89. self.conv2 = nn.Conv2d(
  90. first_planes,
  91. outplanes,
  92. kernel_size=3,
  93. padding=dilation,
  94. dilation=dilation,
  95. bias=False,
  96. **dd,
  97. )
  98. self.bn2 = norm_layer(outplanes, **dd)
  99. self.se = create_attn(attn_layer, outplanes, **dd)
  100. self.act2 = act_layer(inplace=True)
  101. self.downsample = downsample
  102. self.stride = stride
  103. self.dilation = dilation
  104. self.drop_path = drop_path
  105. def zero_init_last(self) -> None:
  106. """Initialize the last batch norm layer weights to zero for better convergence."""
  107. if getattr(self.bn2, 'weight', None) is not None:
  108. nn.init.zeros_(self.bn2.weight)
  109. def forward(self, x: torch.Tensor) -> torch.Tensor:
  110. shortcut = x
  111. x = self.conv1(x)
  112. x = self.bn1(x)
  113. x = self.drop_block(x)
  114. x = self.act1(x)
  115. x = self.aa(x)
  116. x = self.conv2(x)
  117. x = self.bn2(x)
  118. if self.se is not None:
  119. x = self.se(x)
  120. if self.drop_path is not None:
  121. x = self.drop_path(x)
  122. if self.downsample is not None:
  123. shortcut = self.downsample(shortcut)
  124. x += shortcut
  125. x = self.act2(x)
  126. return x
  127. class Bottleneck(nn.Module):
  128. """Bottleneck residual block for ResNet.
  129. This is the bottleneck block used in ResNet-50, ResNet-101, and ResNet-152.
  130. """
  131. expansion = 4
  132. def __init__(
  133. self,
  134. inplanes: int,
  135. planes: int,
  136. stride: int = 1,
  137. downsample: Optional[nn.Module] = None,
  138. cardinality: int = 1,
  139. base_width: int = 64,
  140. reduce_first: int = 1,
  141. dilation: int = 1,
  142. first_dilation: Optional[int] = None,
  143. act_layer: Type[nn.Module] = nn.ReLU,
  144. norm_layer: Type[nn.Module] = nn.BatchNorm2d,
  145. attn_layer: Optional[Type[nn.Module]] = None,
  146. aa_layer: Optional[Type[nn.Module]] = None,
  147. drop_block: Optional[Type[nn.Module]] = None,
  148. drop_path: Optional[nn.Module] = None,
  149. device=None,
  150. dtype=None,
  151. ) -> None:
  152. """
  153. Args:
  154. inplanes: Input channel dimensionality.
  155. planes: Used to determine output channel dimensionalities.
  156. stride: Stride used in convolution layers.
  157. downsample: Optional downsample layer for residual path.
  158. cardinality: Number of convolution groups.
  159. base_width: Base width used to determine output channel dimensionality.
  160. reduce_first: Reduction factor for first convolution output width of residual blocks.
  161. dilation: Dilation rate for convolution layers.
  162. first_dilation: Dilation rate for first convolution layer.
  163. act_layer: Activation layer class.
  164. norm_layer: Normalization layer class.
  165. attn_layer: Attention layer class.
  166. aa_layer: Anti-aliasing layer class.
  167. drop_block: DropBlock layer class.
  168. drop_path: Optional DropPath layer instance.
  169. """
  170. dd = {'device': device, 'dtype': dtype}
  171. super().__init__()
  172. width = int(math.floor(planes * (base_width / 64)) * cardinality)
  173. first_planes = width // reduce_first
  174. outplanes = planes * self.expansion
  175. first_dilation = first_dilation or dilation
  176. use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
  177. self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False, **dd)
  178. self.bn1 = norm_layer(first_planes, **dd)
  179. self.act1 = act_layer(inplace=True)
  180. self.conv2 = nn.Conv2d(
  181. first_planes,
  182. width,
  183. kernel_size=3,
  184. stride=1 if use_aa else stride,
  185. padding=first_dilation,
  186. dilation=first_dilation,
  187. groups=cardinality,
  188. bias=False,
  189. **dd,
  190. )
  191. self.bn2 = norm_layer(width, **dd)
  192. self.drop_block = drop_block() if drop_block is not None else nn.Identity()
  193. self.act2 = act_layer(inplace=True)
  194. self.aa = create_aa(aa_layer, channels=width, stride=stride, enable=use_aa, **dd)
  195. self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False, **dd)
  196. self.bn3 = norm_layer(outplanes, **dd)
  197. self.se = create_attn(attn_layer, outplanes, **dd)
  198. self.act3 = act_layer(inplace=True)
  199. self.downsample = downsample
  200. self.stride = stride
  201. self.dilation = dilation
  202. self.drop_path = drop_path
  203. def zero_init_last(self) -> None:
  204. """Initialize the last batch norm layer weights to zero for better convergence."""
  205. if getattr(self.bn3, 'weight', None) is not None:
  206. nn.init.zeros_(self.bn3.weight)
  207. def forward(self, x: torch.Tensor) -> torch.Tensor:
  208. shortcut = x
  209. x = self.conv1(x)
  210. x = self.bn1(x)
  211. x = self.act1(x)
  212. x = self.conv2(x)
  213. x = self.bn2(x)
  214. x = self.drop_block(x)
  215. x = self.act2(x)
  216. x = self.aa(x)
  217. x = self.conv3(x)
  218. x = self.bn3(x)
  219. if self.se is not None:
  220. x = self.se(x)
  221. if self.drop_path is not None:
  222. x = self.drop_path(x)
  223. if self.downsample is not None:
  224. shortcut = self.downsample(shortcut)
  225. x += shortcut
  226. x = self.act3(x)
  227. return x
  228. def downsample_conv(
  229. in_channels: int,
  230. out_channels: int,
  231. kernel_size: int,
  232. stride: int = 1,
  233. dilation: int = 1,
  234. first_dilation: Optional[int] = None,
  235. norm_layer: Optional[Type[nn.Module]] = None,
  236. device=None,
  237. dtype=None,
  238. ) -> nn.Module:
  239. dd = {'device': device, 'dtype': dtype}
  240. norm_layer = norm_layer or nn.BatchNorm2d
  241. kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
  242. first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
  243. p = get_padding(kernel_size, stride, first_dilation)
  244. return nn.Sequential(*[
  245. nn.Conv2d(
  246. in_channels,
  247. out_channels,
  248. kernel_size,
  249. stride=stride,
  250. padding=p,
  251. dilation=first_dilation,
  252. bias=False,
  253. **dd
  254. ),
  255. norm_layer(out_channels, **dd)
  256. ])
  257. def downsample_avg(
  258. in_channels: int,
  259. out_channels: int,
  260. kernel_size: int,
  261. stride: int = 1,
  262. dilation: int = 1,
  263. first_dilation: Optional[int] = None,
  264. norm_layer: Optional[Type[nn.Module]] = None,
  265. device=None,
  266. dtype=None,
  267. ) -> nn.Module:
  268. dd = {'device': device, 'dtype': dtype}
  269. norm_layer = norm_layer or nn.BatchNorm2d
  270. avg_stride = stride if dilation == 1 else 1
  271. if stride == 1 and dilation == 1:
  272. pool = nn.Identity()
  273. else:
  274. avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
  275. pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
  276. return nn.Sequential(*[
  277. pool,
  278. nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False, **dd),
  279. norm_layer(out_channels, **dd)
  280. ])
  281. def drop_blocks(drop_prob: float = 0.) -> List[Optional[partial]]:
  282. """Create DropBlock layer instances for each stage.
  283. Args:
  284. drop_prob: Drop probability for DropBlock.
  285. Returns:
  286. List of DropBlock partial instances or None for each stage.
  287. """
  288. return [
  289. None, None,
  290. partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None,
  291. partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00) if drop_prob else None]
  292. def make_blocks(
  293. block_fns: Tuple[Union[Type[BasicBlock], Type[Bottleneck]], ...],
  294. channels: Tuple[int, ...],
  295. block_repeats: Tuple[int, ...],
  296. inplanes: int,
  297. reduce_first: int = 1,
  298. output_stride: int = 32,
  299. down_kernel_size: int = 1,
  300. avg_down: bool = False,
  301. drop_block_rate: float = 0.,
  302. drop_path_rate: float = 0.,
  303. device=None,
  304. dtype=None,
  305. **kwargs,
  306. ) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]:
  307. """Create ResNet stages with specified block configurations.
  308. Args:
  309. block_fns: Block class to use for each stage.
  310. channels: Number of channels for each stage.
  311. block_repeats: Number of blocks to repeat for each stage.
  312. inplanes: Number of input channels.
  313. reduce_first: Reduction factor for first convolution in each stage.
  314. output_stride: Target output stride of network.
  315. down_kernel_size: Kernel size for downsample layers.
  316. avg_down: Use average pooling for downsample.
  317. drop_block_rate: DropBlock drop rate.
  318. drop_path_rate: Drop path rate for stochastic depth.
  319. **kwargs: Additional arguments passed to block constructors.
  320. Returns:
  321. Tuple of stage modules list and feature info list.
  322. """
  323. dd = {'device': device, 'dtype': dtype}
  324. stages = []
  325. feature_info = []
  326. net_num_blocks = sum(block_repeats)
  327. net_block_idx = 0
  328. net_stride = 4
  329. dilation = prev_dilation = 1
  330. for stage_idx, (block_fn, planes, num_blocks, db) in enumerate(zip(block_fns, channels, block_repeats, drop_blocks(drop_block_rate))):
  331. stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
  332. stride = 1 if stage_idx == 0 else 2
  333. if net_stride >= output_stride:
  334. dilation *= stride
  335. stride = 1
  336. else:
  337. net_stride *= stride
  338. downsample = None
  339. if stride != 1 or inplanes != planes * block_fn.expansion:
  340. down_kwargs = dict(
  341. in_channels=inplanes,
  342. out_channels=planes * block_fn.expansion,
  343. kernel_size=down_kernel_size,
  344. stride=stride,
  345. dilation=dilation,
  346. first_dilation=prev_dilation,
  347. norm_layer=kwargs.get('norm_layer'),
  348. **dd,
  349. )
  350. downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
  351. block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
  352. blocks = []
  353. for block_idx in range(num_blocks):
  354. downsample = downsample if block_idx == 0 else None
  355. stride = stride if block_idx == 0 else 1
  356. block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule
  357. blocks.append(block_fn(
  358. inplanes,
  359. planes,
  360. stride,
  361. downsample,
  362. first_dilation=prev_dilation,
  363. drop_path=DropPath(block_dpr) if block_dpr > 0. else None,
  364. **block_kwargs,
  365. **dd,
  366. ))
  367. prev_dilation = dilation
  368. inplanes = planes * block_fn.expansion
  369. net_block_idx += 1
  370. stages.append((stage_name, nn.Sequential(*blocks)))
  371. feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
  372. return stages, feature_info
  373. class ResNet(nn.Module):
  374. """ResNet / ResNeXt / SE-ResNeXt / SE-Net
  375. This class implements all variants of ResNet, ResNeXt, SE-ResNeXt, and SENet that
  376. * have > 1 stride in the 3x3 conv layer of bottleneck
  377. * have conv-bn-act ordering
  378. This ResNet impl supports a number of stem and downsample options based on the v1c, v1d, v1e, and v1s
  379. variants included in the MXNet Gluon ResNetV1b model. The C and D variants are also discussed in the
  380. 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187. The B variant is equivalent to torchvision default.
  381. ResNet variants (the same modifications can be used in SE/ResNeXt models as well):
  382. * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
  383. * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
  384. * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
  385. * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
  386. * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
  387. * t - 3 layer deep 3x3 stem, stem width = 32 (24, 48, 64), average pool in downsample
  388. * tn - 3 layer deep 3x3 stem, stem width = 32 (24, 32, 64), average pool in downsample
  389. ResNeXt
  390. * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
  391. * same c,d, e, s variants as ResNet can be enabled
  392. SE-ResNeXt
  393. * normal - 7x7 stem, stem_width = 64
  394. * same c, d, e, s variants as ResNet can be enabled
  395. SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64,
  396. reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block
  397. """
  398. def __init__(
  399. self,
  400. block: Union[BasicBlock, Bottleneck],
  401. layers: Tuple[int, ...],
  402. num_classes: int = 1000,
  403. in_chans: int = 3,
  404. output_stride: int = 32,
  405. global_pool: str = 'avg',
  406. cardinality: int = 1,
  407. base_width: int = 64,
  408. stem_width: int = 64,
  409. stem_type: str = '',
  410. replace_stem_pool: bool = False,
  411. block_reduce_first: int = 1,
  412. down_kernel_size: int = 1,
  413. avg_down: bool = False,
  414. channels: Optional[Tuple[int, ...]] = (64, 128, 256, 512),
  415. act_layer: LayerType = nn.ReLU,
  416. norm_layer: LayerType = nn.BatchNorm2d,
  417. aa_layer: Optional[Type[nn.Module]] = None,
  418. drop_rate: float = 0.0,
  419. drop_path_rate: float = 0.,
  420. drop_block_rate: float = 0.,
  421. zero_init_last: bool = True,
  422. block_args: Optional[Dict[str, Any]] = None,
  423. device=None,
  424. dtype=None,
  425. ):
  426. """
  427. Args:
  428. block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck.
  429. layers (List[int]) : number of layers in each block
  430. num_classes (int): number of classification classes (default 1000)
  431. in_chans (int): number of input (color) channels. (default 3)
  432. output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
  433. global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
  434. cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1)
  435. base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64)
  436. stem_width (int): number of channels in stem convolutions (default 64)
  437. stem_type (str): The type of stem (default ''):
  438. * '', default - a single 7x7 conv with a width of stem_width
  439. * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
  440. * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2
  441. replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution
  442. block_reduce_first (int): Reduction factor for first convolution output width of residual blocks,
  443. 1 for all archs except senets, where 2 (default 1)
  444. down_kernel_size (int): kernel size of residual block downsample path,
  445. 1x1 for most, 3x3 for senets (default: 1)
  446. avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False)
  447. act_layer (str, nn.Module): activation layer
  448. norm_layer (str, nn.Module): normalization layer
  449. aa_layer (nn.Module): anti-aliasing layer
  450. drop_rate (float): Dropout probability before classifier, for training (default 0.)
  451. drop_path_rate (float): Stochastic depth drop-path rate (default 0.)
  452. drop_block_rate (float): Drop block rate (default 0.)
  453. zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
  454. block_args (dict): Extra kwargs to pass through to block module
  455. """
  456. super().__init__()
  457. dd = {'device': device, 'dtype': dtype}
  458. block_args = block_args or dict()
  459. assert output_stride in (8, 16, 32)
  460. self.num_classes = num_classes
  461. self.drop_rate = drop_rate
  462. self.grad_checkpointing = False
  463. act_layer = get_act_layer(act_layer)
  464. norm_layer = get_norm_layer(norm_layer)
  465. # Stem
  466. deep_stem = 'deep' in stem_type
  467. inplanes = stem_width * 2 if deep_stem else 64
  468. if deep_stem:
  469. stem_chs = (stem_width, stem_width)
  470. if 'tiered' in stem_type:
  471. stem_chs = (3 * (stem_width // 4), stem_width)
  472. self.conv1 = nn.Sequential(*[
  473. nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False, **dd),
  474. norm_layer(stem_chs[0], **dd),
  475. act_layer(inplace=True),
  476. nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False, **dd),
  477. norm_layer(stem_chs[1], **dd),
  478. act_layer(inplace=True),
  479. nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False, **dd)])
  480. else:
  481. self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False, **dd)
  482. self.bn1 = norm_layer(inplanes, **dd)
  483. self.act1 = act_layer(inplace=True)
  484. self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
  485. # Stem pooling. The name 'maxpool' remains for weight compatibility.
  486. if replace_stem_pool:
  487. self.maxpool = nn.Sequential(*filter(None, [
  488. nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False, **dd),
  489. create_aa(aa_layer, channels=inplanes, stride=2, **dd) if aa_layer is not None else None,
  490. norm_layer(inplanes, **dd),
  491. act_layer(inplace=True),
  492. ]))
  493. else:
  494. if aa_layer is not None:
  495. if issubclass(aa_layer, nn.AvgPool2d):
  496. self.maxpool = aa_layer(2)
  497. else:
  498. self.maxpool = nn.Sequential(*[
  499. nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
  500. aa_layer(channels=inplanes, stride=2, **dd)])
  501. else:
  502. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
  503. # Feature Blocks
  504. block_fns = to_ntuple(len(channels))(block)
  505. stage_modules, stage_feature_info = make_blocks(
  506. block_fns,
  507. channels,
  508. layers,
  509. inplanes,
  510. cardinality=cardinality,
  511. base_width=base_width,
  512. output_stride=output_stride,
  513. reduce_first=block_reduce_first,
  514. avg_down=avg_down,
  515. down_kernel_size=down_kernel_size,
  516. act_layer=act_layer,
  517. norm_layer=norm_layer,
  518. aa_layer=aa_layer,
  519. drop_block_rate=drop_block_rate,
  520. drop_path_rate=drop_path_rate,
  521. **block_args,
  522. **dd,
  523. )
  524. for stage in stage_modules:
  525. self.add_module(*stage) # layer1, layer2, etc
  526. self.feature_info.extend(stage_feature_info)
  527. # Head (Pooling and Classifier)
  528. self.num_features = self.head_hidden_size = channels[-1] * block_fns[-1].expansion
  529. self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool, **dd)
  530. self.init_weights(zero_init_last=zero_init_last)
  531. @torch.jit.ignore
  532. def init_weights(self, zero_init_last: bool = True) -> None:
  533. """Initialize model weights.
  534. Args:
  535. zero_init_last: Zero-initialize the last BN in each residual branch.
  536. """
  537. for n, m in self.named_modules():
  538. if isinstance(m, nn.Conv2d):
  539. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  540. if zero_init_last:
  541. for m in self.modules():
  542. if hasattr(m, 'zero_init_last'):
  543. m.zero_init_last()
  544. @torch.jit.ignore
  545. def group_matcher(self, coarse: bool = False) -> Dict[str, str]:
  546. """Create regex patterns for parameter grouping.
  547. Args:
  548. coarse: Use coarse (stage-level) or fine (block-level) grouping.
  549. Returns:
  550. Dictionary mapping group names to regex patterns.
  551. """
  552. matcher = dict(stem=r'^conv1|bn1|maxpool', blocks=r'^layer(\d+)' if coarse else r'^layer(\d+)\.(\d+)')
  553. return matcher
  @torch.jit.ignore
  def set_grad_checkpointing(self, enable: bool = True) -> None:
      """Enable or disable gradient checkpointing.

      Only records the flag on the instance; `forward_features` reads it to
      decide whether the stages run through `checkpoint_seq`.

      Args:
          enable: Whether to enable gradient checkpointing.
      """
      self.grad_checkpointing = enable
  @torch.jit.ignore
  def get_classifier(self, name_only: bool = False) -> Union[str, nn.Module]:
      """Get the classifier module.

      Args:
          name_only: Return classifier module name instead of module.

      Returns:
          The string ``'fc'`` when ``name_only`` is true, otherwise ``self.fc``.
      """
      if name_only:
          return 'fc'
      return self.fc
  def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None:
      """Reset the classifier head.

      Stores the new class count, then rebuilds both ``self.global_pool`` and
      ``self.fc`` through ``create_classifier`` using the existing
      ``self.num_features``.

      Args:
          num_classes: Number of classes for new classifier.
          global_pool: Global pooling type.
      """
      self.num_classes = num_classes
      pool, classifier = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
      self.global_pool, self.fc = pool, classifier
  def forward_intermediates(
          self,
          x: torch.Tensor,
          indices: Optional[Union[int, List[int]]] = None,
          norm: bool = False,
          stop_early: bool = False,
          output_fmt: str = 'NCHW',
          intermediates_only: bool = False,
  ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
      """Forward features that returns intermediates.

      Five feature positions are addressable: index 0 is the stem activation
      (taken before max-pooling) and indices 1-4 are the outputs of
      ``layer1`` .. ``layer4``.

      Args:
          x: Input image tensor.
          indices: Take last n blocks if int, all if None, select matching indices if sequence.
          norm: Apply norm layer to compatible intermediates. Accepted for
              interface compatibility; this implementation does not read it.
          stop_early: Stop iterating over blocks when last desired intermediate hit.
          output_fmt: Shape of intermediate feature outputs. Only ``'NCHW'`` is accepted.
          intermediates_only: Only return intermediate features.

      Returns:
          Features and list of intermediate features or just intermediate features.
      """
      assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
      intermediates = []
      take_indices, max_index = feature_take_indices(5, indices)

      # Stem: feature index 0 is captured after the activation, before pooling.
      x = self.conv1(x)
      x = self.bn1(x)
      x = self.act1(x)
      if 0 in take_indices:
          intermediates.append(x)
      x = self.maxpool(x)

      layer_names = ('layer1', 'layer2', 'layer3', 'layer4')
      if stop_early:
          # Stage k has feature index k, so the first `max_index` stages suffice.
          layer_names = layer_names[:max_index]
      for feat_idx, name in enumerate(layer_names, start=1):
          x = getattr(self, name)(x)  # won't work with torchscript, but keeps code reasonable
          if feat_idx in take_indices:
              intermediates.append(x)

      if intermediates_only:
          return intermediates
      return x, intermediates
  def prune_intermediate_layers(
          self,
          indices: Union[int, List[int]] = 1,
          prune_norm: bool = False,
          prune_head: bool = True,
  ) -> List[int]:
      """Prune layers not required for specified intermediates.

      Args:
          indices: Indices of intermediate layers to keep.
          prune_norm: Whether to prune normalization layers. Accepted for
              interface compatibility; this implementation does not read it.
          prune_head: Whether to prune the classifier head.

      Returns:
          List of indices that were kept.
      """
      take_indices, max_index = feature_take_indices(5, indices)
      stage_names = ('layer1', 'layer2', 'layer3', 'layer4')
      # Stages past the deepest requested feature index become identity ops.
      for name in stage_names[max_index:]:
          setattr(self, name, nn.Identity())
      if prune_head:
          # Zero classes and an empty pool type are handed to reset_classifier.
          self.reset_classifier(0, '')
      return take_indices
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
      """Forward pass through feature extraction layers.

      Runs the stem (``conv1`` -> ``bn1`` -> ``act1`` -> ``maxpool``) followed by
      ``layer1`` .. ``layer4``. When ``self.grad_checkpointing`` is set and the
      model is not being scripted, the four stages go through ``checkpoint_seq``.

      Args:
          x: Input tensor.

      Returns:
          Output of ``layer4``.
      """
      x = self.conv1(x)
      x = self.bn1(x)
      x = self.act1(x)
      x = self.maxpool(x)

      if self.grad_checkpointing and not torch.jit.is_scripting():
          x = checkpoint_seq([self.layer1, self.layer2, self.layer3, self.layer4], x, flatten=True)
      else:
          # Kept as explicit attribute calls rather than a loop.
          x = self.layer1(x)
          x = self.layer2(x)
          x = self.layer3(x)
          x = self.layer4(x)
      return x
  def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
      """Forward pass through classifier head.

      Applies ``self.global_pool``, then dropout when ``self.drop_rate`` is
      truthy, then (unless ``pre_logits``) the ``self.fc`` classifier.

      Args:
          x: Feature tensor.
          pre_logits: Return features before final classifier layer.

      Returns:
          Output tensor.
      """
      x = self.global_pool(x)
      if self.drop_rate:
          # `training` is forwarded so dropout follows the module's train/eval mode.
          x = F.dropout(x, p=float(self.drop_rate), training=self.training)
      if pre_logits:
          return x
      return self.fc(x)
  def forward(self, x: torch.Tensor) -> torch.Tensor:
      """Forward pass: `forward_features` followed by `forward_head`.

      Args:
          x: Input tensor.

      Returns:
          Output of `forward_head` with its default arguments.
      """
      features = self.forward_features(x)
      return self.forward_head(features)
  def _create_resnet(variant: str, pretrained: bool = False, **kwargs) -> ResNet:
      """Create a ResNet model.

      Thin wrapper that forwards everything to ``build_model_with_cfg`` with
      ``ResNet`` as the model class.

      Args:
          variant: Model variant name.
          pretrained: Load pretrained weights.
          **kwargs: Additional model arguments.

      Returns:
          ResNet model instance.
      """
      return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
  def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Create a default configuration for ResNet models.

      Args:
          url: Weight download URL stored under the ``'url'`` key.
          **kwargs: Extra entries; they are merged last, so they override any
              default with the same key.

      Returns:
          Configuration dictionary.
      """
      return {
          'url': url,
          'num_classes': 1000,
          'input_size': (3, 224, 224),
          'pool_size': (7, 7),
          'crop_pct': 0.875,
          'interpolation': 'bilinear',
          'mean': IMAGENET_DEFAULT_MEAN,
          'std': IMAGENET_DEFAULT_STD,
          'first_conv': 'conv1',
          'classifier': 'fc',
          'license': 'apache-2.0',
          **kwargs
      }
  def _tcfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Create a configuration with bicubic interpolation.

      Args:
          url: Weight download URL.
          **kwargs: Extra entries; they take precedence over the preset below.

      Returns:
          Configuration dictionary produced by ``_cfg``.
      """
      preset = {'interpolation': 'bicubic'}
      return _cfg(url=url, **{**preset, **kwargs})
  def _ttcfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Create a configuration for models trained with timm.

      Preset: bicubic interpolation, a 288x288 test input with test crop 0.95,
      and the timm repository as origin URL.

      Args:
          url: Weight download URL.
          **kwargs: Extra entries; they take precedence over the preset.

      Returns:
          Configuration dictionary produced by ``_cfg``.
      """
      preset = {
          'interpolation': 'bicubic',
          'test_input_size': (3, 288, 288),
          'test_crop_pct': 0.95,
          'origin_url': 'https://github.com/huggingface/pytorch-image-models',
      }
      return _cfg(url=url, **{**preset, **kwargs})
  def _rcfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Create a configuration tagged with paper id arXiv:2110.00476.

      Preset: bicubic interpolation, crop 0.95, a 288x288 test input with test
      crop 1.0, and the timm repository as origin URL.

      NOTE(review): entries built with this helper mostly download from
      ``v0.1-rsb-weights`` releases, so this looks like the RSB recipe preset
      rather than ResNet-RS -- confirm.

      Args:
          url: Weight download URL.
          **kwargs: Extra entries; they take precedence over the preset.

      Returns:
          Configuration dictionary produced by ``_cfg``.
      """
      preset = {
          'interpolation': 'bicubic',
          'crop_pct': 0.95,
          'test_input_size': (3, 288, 288),
          'test_crop_pct': 1.0,
          'origin_url': 'https://github.com/huggingface/pytorch-image-models',
          'paper_ids': 'arXiv:2110.00476',
      }
      return _cfg(url=url, **{**preset, **kwargs})
  def _r3cfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Create a 160x160-input configuration tagged with paper id arXiv:2110.00476.

      Preset: bicubic interpolation, 160x160 input with a 5x5 pool size, crop
      0.95, a 224x224 test input with test crop 0.95, and the timm repository
      as origin URL.

      NOTE(review): entries built with this helper are the ``a3`` weights under
      ``v0.1-rsb-weights``, so this looks like an RSB recipe preset rather than
      ResNet-RS -- confirm.

      Args:
          url: Weight download URL.
          **kwargs: Extra entries; they take precedence over the preset.

      Returns:
          Configuration dictionary produced by ``_cfg``.
      """
      preset = {
          'interpolation': 'bicubic',
          'input_size': (3, 160, 160),
          'pool_size': (5, 5),
          'crop_pct': 0.95,
          'test_input_size': (3, 224, 224),
          'test_crop_pct': 0.95,
          'origin_url': 'https://github.com/huggingface/pytorch-image-models',
          'paper_ids': 'arXiv:2110.00476',
      }
      return _cfg(url=url, **{**preset, **kwargs})
  def _gcfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Create a configuration for Gluon pretrained models.

      Preset: bicubic interpolation and the Gluon model zoo page as origin URL.

      Args:
          url: Weight download URL.
          **kwargs: Extra entries; they take precedence over the preset.

      Returns:
          Configuration dictionary produced by ``_cfg``.
      """
      preset = {
          'interpolation': 'bicubic',
          'origin_url': 'https://cv.gluon.ai/model_zoo/classification.html',
      }
      return _cfg(url=url, **{**preset, **kwargs})
  722. default_cfgs = generate_default_cfgs({
  723. # ResNet and Wide ResNet trained w/ timm (RSB paper and others)
  724. 'resnet10t.c3_in1k': _ttcfg(
  725. hf_hub_id='timm/',
  726. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet10t_176_c3-f3215ab1.pth',
  727. input_size=(3, 176, 176), pool_size=(6, 6), test_crop_pct=0.95, test_input_size=(3, 224, 224),
  728. first_conv='conv1.0'),
  729. 'resnet14t.c3_in1k': _ttcfg(
  730. hf_hub_id='timm/',
  731. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet14t_176_c3-c4ed2c37.pth',
  732. input_size=(3, 176, 176), pool_size=(6, 6), test_crop_pct=0.95, test_input_size=(3, 224, 224),
  733. first_conv='conv1.0'),
  734. 'resnet18.a1_in1k': _rcfg(
  735. hf_hub_id='timm/',
  736. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a1_0-d63eafa0.pth'),
  737. 'resnet18.a2_in1k': _rcfg(
  738. hf_hub_id='timm/',
  739. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a2_0-b61bd467.pth'),
  740. 'resnet18.a3_in1k': _r3cfg(
  741. hf_hub_id='timm/',
  742. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a3_0-40c531c8.pth'),
  743. 'resnet18d.ra2_in1k': _ttcfg(
  744. hf_hub_id='timm/',
  745. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet18d_ra2-48a79e06.pth',
  746. first_conv='conv1.0'),
  747. 'resnet18d.ra4_e3600_r224_in1k': _rcfg(
  748. hf_hub_id='timm/',
  749. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='conv1.0'),
  750. 'resnet34.a1_in1k': _rcfg(
  751. hf_hub_id='timm/',
  752. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a1_0-46f8f793.pth'),
  753. 'resnet34.a2_in1k': _rcfg(
  754. hf_hub_id='timm/',
  755. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a2_0-82d47d71.pth'),
  756. 'resnet34.a3_in1k': _r3cfg(
  757. hf_hub_id='timm/',
  758. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a3_0-a20cabb6.pth',
  759. crop_pct=0.95),
  760. 'resnet34.bt_in1k': _ttcfg(
  761. hf_hub_id='timm/',
  762. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
  763. 'resnet34.ra4_e3600_r224_in1k': _rcfg(
  764. hf_hub_id='timm/',
  765. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9),
  766. 'resnet34d.ra2_in1k': _ttcfg(
  767. hf_hub_id='timm/',
  768. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34d_ra2-f8dcfcaf.pth',
  769. first_conv='conv1.0'),
  770. 'resnet26.bt_in1k': _ttcfg(
  771. hf_hub_id='timm/',
  772. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth'),
  773. 'resnet26d.bt_in1k': _ttcfg(
  774. hf_hub_id='timm/',
  775. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26d-69e92c46.pth',
  776. first_conv='conv1.0'),
  777. 'resnet26t.ra2_in1k': _ttcfg(
  778. hf_hub_id='timm/',
  779. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet26t_256_ra2-6f6fa748.pth',
  780. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
  781. crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
  782. 'resnet50.a1_in1k': _rcfg(
  783. hf_hub_id='timm/',
  784. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth'),
  785. 'resnet50.a1h_in1k': _rcfg(
  786. hf_hub_id='timm/',
  787. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth',
  788. input_size=(3, 176, 176), pool_size=(6, 6), crop_pct=0.9, test_input_size=(3, 224, 224), test_crop_pct=1.0),
  789. 'resnet50.a2_in1k': _rcfg(
  790. hf_hub_id='timm/',
  791. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a2_0-a2746f79.pth'),
  792. 'resnet50.a3_in1k': _r3cfg(
  793. hf_hub_id='timm/',
  794. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a3_0-59cae1ef.pth'),
  795. 'resnet50.b1k_in1k': _rcfg(
  796. hf_hub_id='timm/',
  797. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_b1k-532a802a.pth'),
  798. 'resnet50.b2k_in1k': _rcfg(
  799. hf_hub_id='timm/',
  800. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_b2k-1ba180c1.pth'),
  801. 'resnet50.c1_in1k': _rcfg(
  802. hf_hub_id='timm/',
  803. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_c1-5ba5e060.pth'),
  804. 'resnet50.c2_in1k': _rcfg(
  805. hf_hub_id='timm/',
  806. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_c2-d01e05b2.pth'),
  807. 'resnet50.d_in1k': _rcfg(
  808. hf_hub_id='timm/',
  809. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_d-f39db8af.pth'),
  810. 'resnet50.ram_in1k': _ttcfg(
  811. hf_hub_id='timm/',
  812. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth'),
  813. 'resnet50.am_in1k': _tcfg(
  814. hf_hub_id='timm/',
  815. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnet50_am-6c502b37.pth'),
  816. 'resnet50.ra_in1k': _ttcfg(
  817. hf_hub_id='timm/',
  818. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnet50_ra-85ebb6e5.pth'),
  819. 'resnet50.bt_in1k': _ttcfg(
  820. hf_hub_id='timm/',
  821. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth'),
  822. 'resnet50d.ra2_in1k': _ttcfg(
  823. hf_hub_id='timm/',
  824. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
  825. first_conv='conv1.0'),
  826. 'resnet50d.ra4_e3600_r224_in1k': _rcfg(
  827. hf_hub_id='timm/',
  828. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
  829. crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0,
  830. first_conv='conv1.0'),
  831. 'resnet50d.a1_in1k': _rcfg(
  832. hf_hub_id='timm/',
  833. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a1_0-e20cff14.pth',
  834. first_conv='conv1.0'),
  835. 'resnet50d.a2_in1k': _rcfg(
  836. hf_hub_id='timm/',
  837. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a2_0-a3adc64d.pth',
  838. first_conv='conv1.0'),
  839. 'resnet50d.a3_in1k': _r3cfg(
  840. hf_hub_id='timm/',
  841. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50d_a3_0-403fdfad.pth',
  842. first_conv='conv1.0'),
  843. 'resnet50t.untrained': _ttcfg(first_conv='conv1.0'),
  844. 'resnet101.a1h_in1k': _rcfg(
  845. hf_hub_id='timm/',
  846. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1h-36d3f2aa.pth'),
  847. 'resnet101.a1_in1k': _rcfg(
  848. hf_hub_id='timm/',
  849. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1_0-cdcb52a9.pth'),
  850. 'resnet101.a2_in1k': _rcfg(
  851. hf_hub_id='timm/',
  852. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a2_0-6edb36c7.pth'),
  853. 'resnet101.a3_in1k': _r3cfg(
  854. hf_hub_id='timm/',
  855. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a3_0-1db14157.pth'),
  856. 'resnet101d.ra2_in1k': _ttcfg(
  857. hf_hub_id='timm/',
  858. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet101d_ra2-2803ffab.pth',
  859. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95,
  860. test_crop_pct=1.0, test_input_size=(3, 320, 320)),
  861. 'resnet152.a1h_in1k': _rcfg(
  862. hf_hub_id='timm/',
  863. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a1h-dc400468.pth'),
  864. 'resnet152.a1_in1k': _rcfg(
  865. hf_hub_id='timm/',
  866. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a1_0-2eee8a7a.pth'),
  867. 'resnet152.a2_in1k': _rcfg(
  868. hf_hub_id='timm/',
  869. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a2_0-b4c6978f.pth'),
  870. 'resnet152.a3_in1k': _r3cfg(
  871. hf_hub_id='timm/',
  872. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet152_a3_0-134d4688.pth'),
  873. 'resnet152d.ra2_in1k': _ttcfg(
  874. hf_hub_id='timm/',
  875. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet152d_ra2-5cac0439.pth',
  876. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95,
  877. test_crop_pct=1.0, test_input_size=(3, 320, 320)),
  878. 'resnet200.untrained': _ttcfg(),
  879. 'resnet200d.ra2_in1k': _ttcfg(
  880. hf_hub_id='timm/',
  881. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet200d_ra2-bdba9bf9.pth',
  882. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95,
  883. test_crop_pct=1.0, test_input_size=(3, 320, 320)),
  884. 'wide_resnet50_2.racm_in1k': _ttcfg(
  885. hf_hub_id='timm/',
  886. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/wide_resnet50_racm-8234f177.pth'),
  887. # torchvision resnet weights
  888. 'resnet18.tv_in1k': _cfg(
  889. hf_hub_id='timm/',
  890. url='https://download.pytorch.org/models/resnet18-f37072fd.pth',
  891. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  892. 'resnet34.tv_in1k': _cfg(
  893. hf_hub_id='timm/',
  894. url='https://download.pytorch.org/models/resnet34-b627a593.pth',
  895. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  896. 'resnet50.tv_in1k': _cfg(
  897. hf_hub_id='timm/',
  898. url='https://download.pytorch.org/models/resnet50-0676ba61.pth',
  899. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  900. 'resnet50.tv2_in1k': _cfg(
  901. hf_hub_id='timm/',
  902. url='https://download.pytorch.org/models/resnet50-11ad3fa6.pth',
  903. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  904. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  905. 'resnet101.tv_in1k': _cfg(
  906. hf_hub_id='timm/',
  907. url='https://download.pytorch.org/models/resnet101-63fe2227.pth',
  908. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  909. 'resnet101.tv2_in1k': _cfg(
  910. hf_hub_id='timm/',
  911. url='https://download.pytorch.org/models/resnet101-cd907fc2.pth',
  912. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  913. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  914. 'resnet152.tv_in1k': _cfg(
  915. hf_hub_id='timm/',
  916. url='https://download.pytorch.org/models/resnet152-394f9c45.pth',
  917. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  918. 'resnet152.tv2_in1k': _cfg(
  919. hf_hub_id='timm/',
  920. url='https://download.pytorch.org/models/resnet152-f82ba261.pth',
  921. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  922. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  923. 'wide_resnet50_2.tv_in1k': _cfg(
  924. hf_hub_id='timm/',
  925. url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
  926. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  927. 'wide_resnet50_2.tv2_in1k': _cfg(
  928. hf_hub_id='timm/',
  929. url='https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth',
  930. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  931. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  932. 'wide_resnet101_2.tv_in1k': _cfg(
  933. hf_hub_id='timm/',
  934. url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
  935. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  936. 'wide_resnet101_2.tv2_in1k': _cfg(
  937. hf_hub_id='timm/',
  938. url='https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth',
  939. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  940. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  941. # ResNets w/ alternative norm layers
  942. 'resnet50_gn.a1h_in1k': _ttcfg(
  943. hf_hub_id='timm/',
  944. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth',
  945. crop_pct=0.94),
  946. # ResNeXt trained in timm (RSB paper and others)
  947. 'resnext50_32x4d.a1h_in1k': _rcfg(
  948. hf_hub_id='timm/',
  949. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1h-0146ab0a.pth'),
  950. 'resnext50_32x4d.a1_in1k': _rcfg(
  951. hf_hub_id='timm/',
  952. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a1_0-b5a91a1d.pth'),
  953. 'resnext50_32x4d.a2_in1k': _rcfg(
  954. hf_hub_id='timm/',
  955. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a2_0-efc76add.pth'),
  956. 'resnext50_32x4d.a3_in1k': _r3cfg(
  957. hf_hub_id='timm/',
  958. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnext50_32x4d_a3_0-3e450271.pth'),
  959. 'resnext50_32x4d.ra_in1k': _ttcfg(
  960. hf_hub_id='timm/',
  961. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d_ra-d733960d.pth'),
  962. 'resnext50d_32x4d.bt_in1k': _ttcfg(
  963. hf_hub_id='timm/',
  964. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50d_32x4d-103e99f8.pth',
  965. first_conv='conv1.0'),
  966. 'resnext101_32x4d.untrained': _ttcfg(),
  967. 'resnext101_64x4d.c1_in1k': _rcfg(
  968. hf_hub_id='timm/',
  969. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnext101_64x4d_c-0d0e0cc0.pth'),
  970. # torchvision ResNeXt weights
  971. 'resnext50_32x4d.tv_in1k': _cfg(
  972. hf_hub_id='timm/',
  973. url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
  974. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  975. 'resnext101_32x8d.tv_in1k': _cfg(
  976. hf_hub_id='timm/',
  977. url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
  978. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  979. 'resnext101_64x4d.tv_in1k': _cfg(
  980. hf_hub_id='timm/',
  981. url='https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth',
  982. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  983. 'resnext50_32x4d.tv2_in1k': _cfg(
  984. hf_hub_id='timm/',
  985. url='https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth',
  986. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  987. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  988. 'resnext101_32x8d.tv2_in1k': _cfg(
  989. hf_hub_id='timm/',
  990. url='https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth',
  991. input_size=(3, 176, 176), pool_size=(6, 6), test_input_size=(3, 224, 224), test_crop_pct=0.965,
  992. license='bsd-3-clause', origin_url='https://github.com/pytorch/vision'),
  993. # ResNeXt models - Weakly Supervised Pretraining on Instagram Hashtags
  994. # from https://github.com/facebookresearch/WSL-Images
  995. # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
  996. 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k': _cfg(
  997. hf_hub_id='timm/',
  998. url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth',
  999. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'),
  1000. 'resnext101_32x16d.fb_wsl_ig1b_ft_in1k': _cfg(
  1001. hf_hub_id='timm/',
  1002. url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth',
  1003. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'),
  1004. 'resnext101_32x32d.fb_wsl_ig1b_ft_in1k': _cfg(
  1005. hf_hub_id='timm/',
  1006. url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth',
  1007. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'),
  1008. 'resnext101_32x48d.fb_wsl_ig1b_ft_in1k': _cfg(
  1009. hf_hub_id='timm/',
  1010. url='https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth',
  1011. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/WSL-Images'),
  1012. # Semi-Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
  1013. # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
  1014. 'resnet18.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1015. hf_hub_id='timm/',
  1016. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth',
  1017. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1018. 'resnet50.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1019. hf_hub_id='timm/',
  1020. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet50-08389792.pth',
  1021. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1022. 'resnext50_32x4d.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1023. hf_hub_id='timm/',
  1024. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext50_32x4-ddb3e555.pth',
  1025. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1026. 'resnext101_32x4d.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1027. hf_hub_id='timm/',
  1028. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x4-dc43570a.pth',
  1029. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1030. 'resnext101_32x8d.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1031. hf_hub_id='timm/',
  1032. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x8-2cfe2f8b.pth',
  1033. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1034. 'resnext101_32x16d.fb_ssl_yfcc100m_ft_in1k': _cfg(
  1035. hf_hub_id='timm/',
  1036. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnext101_32x16-15fffa57.pth',
  1037. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1038. # Semi-Weakly Supervised ResNe*t models from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models
  1039. # Please note the CC-BY-NC 4.0 license on these weights, non-commercial use only.
  1040. 'resnet18.fb_swsl_ig1b_ft_in1k': _cfg(
  1041. hf_hub_id='timm/',
  1042. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth',
  1043. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1044. 'resnet50.fb_swsl_ig1b_ft_in1k': _cfg(
  1045. hf_hub_id='timm/',
  1046. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet50-16a12f1b.pth',
  1047. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1048. 'resnext50_32x4d.fb_swsl_ig1b_ft_in1k': _cfg(
  1049. hf_hub_id='timm/',
  1050. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext50_32x4-72679e44.pth',
  1051. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1052. 'resnext101_32x4d.fb_swsl_ig1b_ft_in1k': _cfg(
  1053. hf_hub_id='timm/',
  1054. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x4-3f87e46b.pth',
  1055. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1056. 'resnext101_32x8d.fb_swsl_ig1b_ft_in1k': _cfg(
  1057. hf_hub_id='timm/',
  1058. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth',
  1059. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1060. 'resnext101_32x16d.fb_swsl_ig1b_ft_in1k': _cfg(
  1061. hf_hub_id='timm/',
  1062. url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth',
  1063. license='cc-by-nc-4.0', origin_url='https://github.com/facebookresearch/semi-supervised-ImageNet1K-models'),
  1064. # Efficient Channel Attention ResNets
  1065. 'ecaresnet26t.ra2_in1k': _ttcfg(
  1066. hf_hub_id='timm/',
  1067. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth',
  1068. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
  1069. test_crop_pct=0.95, test_input_size=(3, 320, 320)),
  1070. 'ecaresnetlight.miil_in1k': _tcfg(
  1071. hf_hub_id='timm/',
  1072. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnetlight-75a9c627.pth',
  1073. test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1074. 'ecaresnet50d.miil_in1k': _tcfg(
  1075. hf_hub_id='timm/',
  1076. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet50d-93c81e3b.pth',
  1077. first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1078. 'ecaresnet50d_pruned.miil_in1k': _tcfg(
  1079. hf_hub_id='timm/',
  1080. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet50d_p-e4fa23c2.pth',
  1081. first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1082. 'ecaresnet50t.ra2_in1k': _tcfg(
  1083. hf_hub_id='timm/',
  1084. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth',
  1085. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
  1086. test_crop_pct=0.95, test_input_size=(3, 320, 320)),
  1087. 'ecaresnet50t.a1_in1k': _rcfg(
  1088. hf_hub_id='timm/',
  1089. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/ecaresnet50t_a1_0-99bd76a8.pth',
  1090. first_conv='conv1.0'),
  1091. 'ecaresnet50t.a2_in1k': _rcfg(
  1092. hf_hub_id='timm/',
  1093. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/ecaresnet50t_a2_0-b1c7b745.pth',
  1094. first_conv='conv1.0'),
  1095. 'ecaresnet50t.a3_in1k': _r3cfg(
  1096. hf_hub_id='timm/',
  1097. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/ecaresnet50t_a3_0-8cc311f1.pth',
  1098. first_conv='conv1.0'),
  1099. 'ecaresnet101d.miil_in1k': _tcfg(
  1100. hf_hub_id='timm/',
  1101. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet101d-153dad65.pth',
  1102. first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1103. 'ecaresnet101d_pruned.miil_in1k': _tcfg(
  1104. hf_hub_id='timm/',
  1105. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/ecaresnet101d_p-9e74cb91.pth',
  1106. first_conv='conv1.0', test_crop_pct=0.95, test_input_size=(3, 288, 288)),
  1107. 'ecaresnet200d.untrained': _ttcfg(
  1108. first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.95, pool_size=(8, 8)),
  1109. 'ecaresnet269d.ra2_in1k': _ttcfg(
  1110. hf_hub_id='timm/',
  1111. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth',
  1112. first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=0.95,
  1113. test_crop_pct=1.0, test_input_size=(3, 352, 352)),
  1114. # Efficient Channel Attention ResNeXts
  1115. 'ecaresnext26t_32x4d.untrained': _tcfg(first_conv='conv1.0'),
  1116. 'ecaresnext50t_32x4d.untrained': _tcfg(first_conv='conv1.0'),
  1117. # Squeeze-Excitation ResNets, to eventually replace the models in senet.py
  1118. 'seresnet18.untrained': _ttcfg(),
  1119. 'seresnet34.untrained': _ttcfg(),
  1120. 'seresnet50.a1_in1k': _rcfg(
  1121. hf_hub_id='timm/',
  1122. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/seresnet50_a1_0-ffa00869.pth',
  1123. crop_pct=0.95),
  1124. 'seresnet50.a2_in1k': _rcfg(
  1125. hf_hub_id='timm/',
  1126. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/seresnet50_a2_0-850de0d9.pth',
  1127. crop_pct=0.95),
  1128. 'seresnet50.a3_in1k': _r3cfg(
  1129. hf_hub_id='timm/',
  1130. url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/seresnet50_a3_0-317ecd56.pth',
  1131. crop_pct=0.95),
  1132. 'seresnet50.ra2_in1k': _ttcfg(
  1133. hf_hub_id='timm/',
  1134. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet50_ra_224-8efdb4bb.pth'),
  1135. 'seresnet50t.untrained': _ttcfg(
  1136. first_conv='conv1.0'),
  1137. 'seresnet101.untrained': _ttcfg(),
  1138. 'seresnet152.untrained': _ttcfg(),
  1139. 'seresnet152d.ra2_in1k': _ttcfg(
  1140. hf_hub_id='timm/',
  1141. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnet152d_ra2-04464dd2.pth',
  1142. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.95,
  1143. test_crop_pct=1.0, test_input_size=(3, 320, 320)
  1144. ),
  1145. 'seresnet200d.untrained': _ttcfg(
  1146. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
  1147. 'seresnet269d.untrained': _ttcfg(
  1148. first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8)),
  1149. # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
  1150. 'seresnext26d_32x4d.bt_in1k': _ttcfg(
  1151. hf_hub_id='timm/',
  1152. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
  1153. first_conv='conv1.0'),
  1154. 'seresnext26t_32x4d.bt_in1k': _ttcfg(
  1155. hf_hub_id='timm/',
  1156. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
  1157. first_conv='conv1.0'),
  1158. 'seresnext50_32x4d.racm_in1k': _ttcfg(
  1159. hf_hub_id='timm/',
  1160. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext50_32x4d_racm-a304a460.pth'),
  1161. 'seresnext101_32x4d.untrained': _ttcfg(),
  1162. 'seresnext101_32x8d.ah_in1k': _rcfg(
  1163. hf_hub_id='timm/',
  1164. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth'),
  1165. 'seresnext101d_32x8d.ah_in1k': _rcfg(
  1166. hf_hub_id='timm/',
  1167. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101d_32x8d_ah-191d7b94.pth',
  1168. first_conv='conv1.0'),
  1169. # ResNets with anti-aliasing / blur pool
  1170. 'resnetaa50d.sw_in12k_ft_in1k': _ttcfg(
  1171. hf_hub_id='timm/',
  1172. first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1173. 'resnetaa101d.sw_in12k_ft_in1k': _ttcfg(
  1174. hf_hub_id='timm/',
  1175. first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1176. 'seresnextaa101d_32x8d.sw_in12k_ft_in1k_288': _ttcfg(
  1177. hf_hub_id='timm/',
  1178. crop_pct=0.95, input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), test_crop_pct=1.0,
  1179. first_conv='conv1.0'),
  1180. 'seresnextaa101d_32x8d.sw_in12k_ft_in1k': _ttcfg(
  1181. hf_hub_id='timm/',
  1182. first_conv='conv1.0', test_crop_pct=1.0),
  1183. 'seresnextaa201d_32x8d.sw_in12k_ft_in1k_384': _cfg(
  1184. hf_hub_id='timm/',
  1185. interpolation='bicubic', first_conv='conv1.0', pool_size=(12, 12), input_size=(3, 384, 384), crop_pct=1.0),
  1186. 'seresnextaa201d_32x8d.sw_in12k': _cfg(
  1187. hf_hub_id='timm/',
  1188. num_classes=11821, interpolation='bicubic', first_conv='conv1.0',
  1189. crop_pct=0.95, input_size=(3, 320, 320), pool_size=(10, 10), test_input_size=(3, 384, 384), test_crop_pct=1.0),
  1190. 'resnetaa50d.sw_in12k': _ttcfg(
  1191. hf_hub_id='timm/',
  1192. num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1193. 'resnetaa50d.d_in12k': _ttcfg(
  1194. hf_hub_id='timm/',
  1195. num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1196. 'resnetaa101d.sw_in12k': _ttcfg(
  1197. hf_hub_id='timm/',
  1198. num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1199. 'seresnextaa101d_32x8d.sw_in12k': _ttcfg(
  1200. hf_hub_id='timm/',
  1201. num_classes=11821, first_conv='conv1.0', crop_pct=0.95, test_crop_pct=1.0),
  1202. 'resnetblur18.untrained': _ttcfg(),
  1203. 'resnetblur50.bt_in1k': _ttcfg(
  1204. hf_hub_id='timm/',
  1205. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnetblur50-84f4748f.pth'),
  1206. 'resnetblur50d.untrained': _ttcfg(first_conv='conv1.0'),
  1207. 'resnetblur101d.untrained': _ttcfg(first_conv='conv1.0'),
  1208. 'resnetaa34d.untrained': _ttcfg(first_conv='conv1.0'),
  1209. 'resnetaa50.a1h_in1k': _rcfg(
  1210. hf_hub_id='timm/',
  1211. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnetaa50_a1h-4cf422b3.pth'),
  1212. 'seresnetaa50d.untrained': _ttcfg(first_conv='conv1.0'),
  1213. 'seresnextaa101d_32x8d.ah_in1k': _rcfg(
  1214. hf_hub_id='timm/',
  1215. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnextaa101d_32x8d_ah-83c8ae12.pth',
  1216. first_conv='conv1.0'),
  1217. # ResNet-RS models
  1218. 'resnetrs50.tf_in1k': _cfg(
  1219. hf_hub_id='timm/',
  1220. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs50_ema-6b53758b.pth',
  1221. input_size=(3, 160, 160), pool_size=(5, 5), crop_pct=0.91, test_input_size=(3, 224, 224),
  1222. interpolation='bicubic', first_conv='conv1.0'),
  1223. 'resnetrs101.tf_in1k': _cfg(
  1224. hf_hub_id='timm/',
  1225. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs101_i192_ema-1509bbf6.pth',
  1226. input_size=(3, 192, 192), pool_size=(6, 6), crop_pct=0.94, test_input_size=(3, 288, 288),
  1227. interpolation='bicubic', first_conv='conv1.0'),
  1228. 'resnetrs152.tf_in1k': _cfg(
  1229. hf_hub_id='timm/',
  1230. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs152_i256_ema-a9aff7f9.pth',
  1231. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
  1232. interpolation='bicubic', first_conv='conv1.0'),
  1233. 'resnetrs200.tf_in1k': _cfg(
  1234. hf_hub_id='timm/',
  1235. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/resnetrs200_c-6b698b88.pth',
  1236. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 320, 320),
  1237. interpolation='bicubic', first_conv='conv1.0'),
  1238. 'resnetrs270.tf_in1k': _cfg(
  1239. hf_hub_id='timm/',
  1240. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs270_ema-b40e674c.pth',
  1241. input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, test_input_size=(3, 352, 352),
  1242. interpolation='bicubic', first_conv='conv1.0'),
  1243. 'resnetrs350.tf_in1k': _cfg(
  1244. hf_hub_id='timm/',
  1245. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs350_i256_ema-5a1aa8f1.pth',
  1246. input_size=(3, 288, 288), pool_size=(9, 9), crop_pct=1.0, test_input_size=(3, 384, 384),
  1247. interpolation='bicubic', first_conv='conv1.0'),
  1248. 'resnetrs420.tf_in1k': _cfg(
  1249. hf_hub_id='timm/',
  1250. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rs-weights/resnetrs420_ema-972dee69.pth',
  1251. input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, test_input_size=(3, 416, 416),
  1252. interpolation='bicubic', first_conv='conv1.0'),
  1253. # gluon resnet weights
  1254. 'resnet18.gluon_in1k': _gcfg(
  1255. hf_hub_id='timm/',
  1256. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet18_v1b-0757602b.pth'),
  1257. 'resnet34.gluon_in1k': _gcfg(
  1258. hf_hub_id='timm/',
  1259. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet34_v1b-c6d82d59.pth'),
  1260. 'resnet50.gluon_in1k': _gcfg(
  1261. hf_hub_id='timm/',
  1262. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1b-0ebe02e2.pth'),
  1263. 'resnet101.gluon_in1k': _gcfg(
  1264. hf_hub_id='timm/',
  1265. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1b-3b017079.pth'),
  1266. 'resnet152.gluon_in1k': _gcfg(
  1267. hf_hub_id='timm/',
  1268. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1b-c1edb0dd.pth'),
  1269. 'resnet50c.gluon_in1k': _gcfg(
  1270. hf_hub_id='timm/',
  1271. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1c-48092f55.pth',
  1272. first_conv='conv1.0'),
  1273. 'resnet101c.gluon_in1k': _gcfg(
  1274. hf_hub_id='timm/',
  1275. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1c-1f26822a.pth',
  1276. first_conv='conv1.0'),
  1277. 'resnet152c.gluon_in1k': _gcfg(
  1278. hf_hub_id='timm/',
  1279. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1c-a3bb0b98.pth',
  1280. first_conv='conv1.0'),
  1281. 'resnet50d.gluon_in1k': _gcfg(
  1282. hf_hub_id='timm/',
  1283. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1d-818a1b1b.pth',
  1284. first_conv='conv1.0'),
  1285. 'resnet101d.gluon_in1k': _gcfg(
  1286. hf_hub_id='timm/',
  1287. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1d-0f9c8644.pth',
  1288. first_conv='conv1.0'),
  1289. 'resnet152d.gluon_in1k': _gcfg(
  1290. hf_hub_id='timm/',
  1291. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1d-bd354e12.pth',
  1292. first_conv='conv1.0'),
  1293. 'resnet50s.gluon_in1k': _gcfg(
  1294. hf_hub_id='timm/',
  1295. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet50_v1s-1762acc0.pth',
  1296. first_conv='conv1.0'),
  1297. 'resnet101s.gluon_in1k': _gcfg(
  1298. hf_hub_id='timm/',
  1299. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet101_v1s-60fe0cc1.pth',
  1300. first_conv='conv1.0'),
  1301. 'resnet152s.gluon_in1k': _gcfg(
  1302. hf_hub_id='timm/',
  1303. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnet152_v1s-dcc41b81.pth',
  1304. first_conv='conv1.0'),
  1305. 'resnext50_32x4d.gluon_in1k': _gcfg(
  1306. hf_hub_id='timm/',
  1307. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'),
  1308. 'resnext101_32x4d.gluon_in1k': _gcfg(
  1309. hf_hub_id='timm/',
  1310. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'),
  1311. 'resnext101_64x4d.gluon_in1k': _gcfg(
  1312. hf_hub_id='timm/',
  1313. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'),
  1314. 'seresnext50_32x4d.gluon_in1k': _gcfg(
  1315. hf_hub_id='timm/',
  1316. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'),
  1317. 'seresnext101_32x4d.gluon_in1k': _gcfg(
  1318. hf_hub_id='timm/',
  1319. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'),
  1320. 'seresnext101_64x4d.gluon_in1k': _gcfg(
  1321. hf_hub_id='timm/',
  1322. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'),
  1323. 'senet154.gluon_in1k': _gcfg(
  1324. hf_hub_id='timm/',
  1325. url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth',
  1326. first_conv='conv1.0'),
  1327. 'test_resnet.r160_in1k': _cfg(
  1328. hf_hub_id='timm/',
  1329. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95,
  1330. input_size=(3, 160, 160), pool_size=(5, 5), first_conv='conv1.0'),
  1331. })
  1332. @register_model
  1333. def resnet10t(pretrained: bool = False, **kwargs) -> ResNet:
  1334. """Constructs a ResNet-10-T model.
  1335. """
  1336. model_args = dict(block=BasicBlock, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True)
  1337. return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))
  1338. @register_model
  1339. def resnet14t(pretrained: bool = False, **kwargs) -> ResNet:
  1340. """Constructs a ResNet-14-T model.
  1341. """
  1342. model_args = dict(block=Bottleneck, layers=(1, 1, 1, 1), stem_width=32, stem_type='deep_tiered', avg_down=True)
  1343. return _create_resnet('resnet14t', pretrained, **dict(model_args, **kwargs))
  1344. @register_model
  1345. def resnet18(pretrained: bool = False, **kwargs) -> ResNet:
  1346. """Constructs a ResNet-18 model.
  1347. """
  1348. model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2))
  1349. return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs))
  1350. @register_model
  1351. def resnet18d(pretrained: bool = False, **kwargs) -> ResNet:
  1352. """Constructs a ResNet-18-D model.
  1353. """
  1354. model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True)
  1355. return _create_resnet('resnet18d', pretrained, **dict(model_args, **kwargs))
  1356. @register_model
  1357. def resnet34(pretrained: bool = False, **kwargs) -> ResNet:
  1358. """Constructs a ResNet-34 model.
  1359. """
  1360. model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3))
  1361. return _create_resnet('resnet34', pretrained, **dict(model_args, **kwargs))
  1362. @register_model
  1363. def resnet34d(pretrained: bool = False, **kwargs) -> ResNet:
  1364. """Constructs a ResNet-34-D model.
  1365. """
  1366. model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True)
  1367. return _create_resnet('resnet34d', pretrained, **dict(model_args, **kwargs))
  1368. @register_model
  1369. def resnet26(pretrained: bool = False, **kwargs) -> ResNet:
  1370. """Constructs a ResNet-26 model.
  1371. """
  1372. model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2))
  1373. return _create_resnet('resnet26', pretrained, **dict(model_args, **kwargs))
  1374. @register_model
  1375. def resnet26t(pretrained: bool = False, **kwargs) -> ResNet:
  1376. """Constructs a ResNet-26-T model.
  1377. """
  1378. model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep_tiered', avg_down=True)
  1379. return _create_resnet('resnet26t', pretrained, **dict(model_args, **kwargs))
  1380. @register_model
  1381. def resnet26d(pretrained: bool = False, **kwargs) -> ResNet:
  1382. """Constructs a ResNet-26-D model.
  1383. """
  1384. model_args = dict(block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32, stem_type='deep', avg_down=True)
  1385. return _create_resnet('resnet26d', pretrained, **dict(model_args, **kwargs))
  1386. @register_model
  1387. def resnet50(pretrained: bool = False, **kwargs) -> ResNet:
  1388. """Constructs a ResNet-50 model.
  1389. """
  1390. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3))
  1391. return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs))
  1392. @register_model
  1393. def resnet50c(pretrained: bool = False, **kwargs) -> ResNet:
  1394. """Constructs a ResNet-50-C model.
  1395. """
  1396. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep')
  1397. return _create_resnet('resnet50c', pretrained, **dict(model_args, **kwargs))
  1398. @register_model
  1399. def resnet50d(pretrained: bool = False, **kwargs) -> ResNet:
  1400. """Constructs a ResNet-50-D model.
  1401. """
  1402. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True)
  1403. return _create_resnet('resnet50d', pretrained, **dict(model_args, **kwargs))
  1404. @register_model
  1405. def resnet50s(pretrained: bool = False, **kwargs) -> ResNet:
  1406. """Constructs a ResNet-50-S model.
  1407. """
  1408. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=64, stem_type='deep')
  1409. return _create_resnet('resnet50s', pretrained, **dict(model_args, **kwargs))
  1410. @register_model
  1411. def resnet50t(pretrained: bool = False, **kwargs) -> ResNet:
  1412. """Constructs a ResNet-50-T model.
  1413. """
  1414. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered', avg_down=True)
  1415. return _create_resnet('resnet50t', pretrained, **dict(model_args, **kwargs))
  1416. @register_model
  1417. def resnet101(pretrained: bool = False, **kwargs) -> ResNet:
  1418. """Constructs a ResNet-101 model.
  1419. """
  1420. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3))
  1421. return _create_resnet('resnet101', pretrained, **dict(model_args, **kwargs))
  1422. @register_model
  1423. def resnet101c(pretrained: bool = False, **kwargs) -> ResNet:
  1424. """Constructs a ResNet-101-C model.
  1425. """
  1426. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep')
  1427. return _create_resnet('resnet101c', pretrained, **dict(model_args, **kwargs))
  1428. @register_model
  1429. def resnet101d(pretrained: bool = False, **kwargs) -> ResNet:
  1430. """Constructs a ResNet-101-D model.
  1431. """
  1432. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True)
  1433. return _create_resnet('resnet101d', pretrained, **dict(model_args, **kwargs))
  1434. @register_model
  1435. def resnet101s(pretrained: bool = False, **kwargs) -> ResNet:
  1436. """Constructs a ResNet-101-S model.
  1437. """
  1438. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), stem_width=64, stem_type='deep')
  1439. return _create_resnet('resnet101s', pretrained, **dict(model_args, **kwargs))
  1440. @register_model
  1441. def resnet152(pretrained: bool = False, **kwargs) -> ResNet:
  1442. """Constructs a ResNet-152 model.
  1443. """
  1444. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3))
  1445. return _create_resnet('resnet152', pretrained, **dict(model_args, **kwargs))
  1446. @register_model
  1447. def resnet152c(pretrained: bool = False, **kwargs) -> ResNet:
  1448. """Constructs a ResNet-152-C model.
  1449. """
  1450. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep')
  1451. return _create_resnet('resnet152c', pretrained, **dict(model_args, **kwargs))
  1452. @register_model
  1453. def resnet152d(pretrained: bool = False, **kwargs) -> ResNet:
  1454. """Constructs a ResNet-152-D model.
  1455. """
  1456. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', avg_down=True)
  1457. return _create_resnet('resnet152d', pretrained, **dict(model_args, **kwargs))
  1458. @register_model
  1459. def resnet152s(pretrained: bool = False, **kwargs) -> ResNet:
  1460. """Constructs a ResNet-152-S model.
  1461. """
  1462. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), stem_width=64, stem_type='deep')
  1463. return _create_resnet('resnet152s', pretrained, **dict(model_args, **kwargs))
  1464. @register_model
  1465. def resnet200(pretrained: bool = False, **kwargs) -> ResNet:
  1466. """Constructs a ResNet-200 model.
  1467. """
  1468. model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3))
  1469. return _create_resnet('resnet200', pretrained, **dict(model_args, **kwargs))
  1470. @register_model
  1471. def resnet200d(pretrained: bool = False, **kwargs) -> ResNet:
  1472. """Constructs a ResNet-200-D model.
  1473. """
  1474. model_args = dict(block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True)
  1475. return _create_resnet('resnet200d', pretrained, **dict(model_args, **kwargs))
  1476. @register_model
  1477. def wide_resnet50_2(pretrained: bool = False, **kwargs) -> ResNet:
  1478. """Constructs a Wide ResNet-50-2 model.
  1479. The model is the same as ResNet except for the bottleneck number of channels
  1480. which is twice larger in every block. The number of channels in outer 1x1
  1481. convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
  1482. channels, and in Wide ResNet-50-2 has 2048-1024-2048.
  1483. """
  1484. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), base_width=128)
  1485. return _create_resnet('wide_resnet50_2', pretrained, **dict(model_args, **kwargs))
  1486. @register_model
  1487. def wide_resnet101_2(pretrained: bool = False, **kwargs) -> ResNet:
  1488. """Constructs a Wide ResNet-101-2 model.
  1489. The model is the same as ResNet except for the bottleneck number of channels
  1490. which is twice larger in every block. The number of channels in outer 1x1
  1491. convolutions is the same.
  1492. """
  1493. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), base_width=128)
  1494. return _create_resnet('wide_resnet101_2', pretrained, **dict(model_args, **kwargs))
  1495. @register_model
  1496. def resnet50_gn(pretrained: bool = False, **kwargs) -> ResNet:
  1497. """Constructs a ResNet-50 model w/ GroupNorm
  1498. """
  1499. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), norm_layer='groupnorm')
  1500. return _create_resnet('resnet50_gn', pretrained, **dict(model_args, **kwargs))
  1501. @register_model
  1502. def resnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1503. """Constructs a ResNeXt50-32x4d model.
  1504. """
  1505. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4)
  1506. return _create_resnet('resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
  1507. @register_model
  1508. def resnext50d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1509. """Constructs a ResNeXt50d-32x4d model. ResNext50 w/ deep stem & avg pool downsample
  1510. """
  1511. model_args = dict(
  1512. block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4,
  1513. stem_width=32, stem_type='deep', avg_down=True)
  1514. return _create_resnet('resnext50d_32x4d', pretrained, **dict(model_args, **kwargs))
  1515. @register_model
  1516. def resnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1517. """Constructs a ResNeXt-101 32x4d model.
  1518. """
  1519. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4)
  1520. return _create_resnet('resnext101_32x4d', pretrained, **dict(model_args, **kwargs))
  1521. @register_model
  1522. def resnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
  1523. """Constructs a ResNeXt-101 32x8d model.
  1524. """
  1525. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8)
  1526. return _create_resnet('resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
  1527. @register_model
  1528. def resnext101_32x16d(pretrained: bool = False, **kwargs) -> ResNet:
  1529. """Constructs a ResNeXt-101 32x16d model
  1530. """
  1531. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=16)
  1532. return _create_resnet('resnext101_32x16d', pretrained, **dict(model_args, **kwargs))
  1533. @register_model
  1534. def resnext101_32x32d(pretrained: bool = False, **kwargs) -> ResNet:
  1535. """Constructs a ResNeXt-101 32x32d model
  1536. """
  1537. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=32)
  1538. return _create_resnet('resnext101_32x32d', pretrained, **dict(model_args, **kwargs))
  1539. @register_model
  1540. def resnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1541. """Constructs a ResNeXt101-64x4d model.
  1542. """
  1543. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4)
  1544. return _create_resnet('resnext101_64x4d', pretrained, **dict(model_args, **kwargs))
  1545. @register_model
  1546. def ecaresnet26t(pretrained: bool = False, **kwargs) -> ResNet:
  1547. """Constructs an ECA-ResNeXt-26-T model.
  1548. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
  1549. in the deep stem and ECA attn.
  1550. """
  1551. model_args = dict(
  1552. block=Bottleneck, layers=(2, 2, 2, 2), stem_width=32,
  1553. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
  1554. return _create_resnet('ecaresnet26t', pretrained, **dict(model_args, **kwargs))
  1555. @register_model
  1556. def ecaresnet50d(pretrained: bool = False, **kwargs) -> ResNet:
  1557. """Constructs a ResNet-50-D model with eca.
  1558. """
  1559. model_args = dict(
  1560. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True,
  1561. block_args=dict(attn_layer='eca'))
  1562. return _create_resnet('ecaresnet50d', pretrained, **dict(model_args, **kwargs))
  1563. @register_model
  1564. def ecaresnet50d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
  1565. """Constructs a ResNet-50-D model pruned with eca.
  1566. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
  1567. """
  1568. model_args = dict(
  1569. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', avg_down=True,
  1570. block_args=dict(attn_layer='eca'))
  1571. return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
  1572. @register_model
  1573. def ecaresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
  1574. """Constructs an ECA-ResNet-50-T model.
  1575. Like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels in the deep stem and ECA attn.
  1576. """
  1577. model_args = dict(
  1578. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32,
  1579. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
  1580. return _create_resnet('ecaresnet50t', pretrained, **dict(model_args, **kwargs))
  1581. @register_model
  1582. def ecaresnetlight(pretrained: bool = False, **kwargs) -> ResNet:
  1583. """Constructs a ResNet-50-D light model with eca.
  1584. """
  1585. model_args = dict(
  1586. block=Bottleneck, layers=(1, 1, 11, 3), stem_width=32, avg_down=True,
  1587. block_args=dict(attn_layer='eca'))
  1588. return _create_resnet('ecaresnetlight', pretrained, **dict(model_args, **kwargs))
  1589. @register_model
  1590. def ecaresnet101d(pretrained: bool = False, **kwargs) -> ResNet:
  1591. """Constructs a ResNet-101-D model with eca.
  1592. """
  1593. model_args = dict(
  1594. block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True,
  1595. block_args=dict(attn_layer='eca'))
  1596. return _create_resnet('ecaresnet101d', pretrained, **dict(model_args, **kwargs))
  1597. @register_model
  1598. def ecaresnet101d_pruned(pretrained: bool = False, **kwargs) -> ResNet:
  1599. """Constructs a ResNet-101-D model pruned with eca.
  1600. The pruning has been obtained using https://arxiv.org/pdf/2002.08258.pdf
  1601. """
  1602. model_args = dict(
  1603. block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', avg_down=True,
  1604. block_args=dict(attn_layer='eca'))
  1605. return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
  1606. @register_model
  1607. def ecaresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
  1608. """Constructs a ResNet-200-D model with ECA.
  1609. """
  1610. model_args = dict(
  1611. block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', avg_down=True,
  1612. block_args=dict(attn_layer='eca'))
  1613. return _create_resnet('ecaresnet200d', pretrained, **dict(model_args, **kwargs))
  1614. @register_model
  1615. def ecaresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
  1616. """Constructs a ResNet-269-D model with ECA.
  1617. """
  1618. model_args = dict(
  1619. block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep', avg_down=True,
  1620. block_args=dict(attn_layer='eca'))
  1621. return _create_resnet('ecaresnet269d', pretrained, **dict(model_args, **kwargs))
  1622. @register_model
  1623. def ecaresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1624. """Constructs an ECA-ResNeXt-26-T model.
  1625. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
  1626. in the deep stem. This model replaces SE module with the ECA module
  1627. """
  1628. model_args = dict(
  1629. block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
  1630. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
  1631. return _create_resnet('ecaresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
  1632. @register_model
  1633. def ecaresnext50t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1634. """Constructs an ECA-ResNeXt-50-T model.
  1635. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
  1636. in the deep stem. This model replaces SE module with the ECA module
  1637. """
  1638. model_args = dict(
  1639. block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
  1640. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
  1641. return _create_resnet('ecaresnext50t_32x4d', pretrained, **dict(model_args, **kwargs))
  1642. @register_model
  1643. def seresnet18(pretrained: bool = False, **kwargs) -> ResNet:
  1644. model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), block_args=dict(attn_layer='se'))
  1645. return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs))
  1646. @register_model
  1647. def seresnet34(pretrained: bool = False, **kwargs) -> ResNet:
  1648. model_args = dict(block=BasicBlock, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se'))
  1649. return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs))
  1650. @register_model
  1651. def seresnet50(pretrained: bool = False, **kwargs) -> ResNet:
  1652. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), block_args=dict(attn_layer='se'))
  1653. return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs))
  1654. @register_model
  1655. def seresnet50t(pretrained: bool = False, **kwargs) -> ResNet:
  1656. model_args = dict(
  1657. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep_tiered',
  1658. avg_down=True, block_args=dict(attn_layer='se'))
  1659. return _create_resnet('seresnet50t', pretrained, **dict(model_args, **kwargs))
  1660. @register_model
  1661. def seresnet101(pretrained: bool = False, **kwargs) -> ResNet:
  1662. model_args = dict(block=Bottleneck, layers=(3, 4, 23, 3), block_args=dict(attn_layer='se'))
  1663. return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs))
  1664. @register_model
  1665. def seresnet152(pretrained: bool = False, **kwargs) -> ResNet:
  1666. model_args = dict(block=Bottleneck, layers=(3, 8, 36, 3), block_args=dict(attn_layer='se'))
  1667. return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs))
  1668. @register_model
  1669. def seresnet152d(pretrained: bool = False, **kwargs) -> ResNet:
  1670. model_args = dict(
  1671. block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep',
  1672. avg_down=True, block_args=dict(attn_layer='se'))
  1673. return _create_resnet('seresnet152d', pretrained, **dict(model_args, **kwargs))
  1674. @register_model
  1675. def seresnet200d(pretrained: bool = False, **kwargs) -> ResNet:
  1676. """Constructs a ResNet-200-D model with SE attn.
  1677. """
  1678. model_args = dict(
  1679. block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep',
  1680. avg_down=True, block_args=dict(attn_layer='se'))
  1681. return _create_resnet('seresnet200d', pretrained, **dict(model_args, **kwargs))
  1682. @register_model
  1683. def seresnet269d(pretrained: bool = False, **kwargs) -> ResNet:
  1684. """Constructs a ResNet-269-D model with SE attn.
  1685. """
  1686. model_args = dict(
  1687. block=Bottleneck, layers=(3, 30, 48, 8), stem_width=32, stem_type='deep',
  1688. avg_down=True, block_args=dict(attn_layer='se'))
  1689. return _create_resnet('seresnet269d', pretrained, **dict(model_args, **kwargs))
  1690. @register_model
  1691. def seresnext26d_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1692. """Constructs a SE-ResNeXt-26-D model.`
  1693. This is technically a 28 layer ResNet, using the 'D' modifier from Gluon / bag-of-tricks for
  1694. combination of deep stem and avg_pool in downsample.
  1695. """
  1696. model_args = dict(
  1697. block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
  1698. stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
  1699. return _create_resnet('seresnext26d_32x4d', pretrained, **dict(model_args, **kwargs))
  1700. @register_model
  1701. def seresnext26t_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1702. """Constructs a SE-ResNet-26-T model.
  1703. This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
  1704. in the deep stem.
  1705. """
  1706. model_args = dict(
  1707. block=Bottleneck, layers=(2, 2, 2, 2), cardinality=32, base_width=4, stem_width=32,
  1708. stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'))
  1709. return _create_resnet('seresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
  1710. @register_model
  1711. def seresnext50_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1712. model_args = dict(
  1713. block=Bottleneck, layers=(3, 4, 6, 3), cardinality=32, base_width=4,
  1714. block_args=dict(attn_layer='se'))
  1715. return _create_resnet('seresnext50_32x4d', pretrained, **dict(model_args, **kwargs))
  1716. @register_model
  1717. def seresnext101_32x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1718. model_args = dict(
  1719. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=4,
  1720. block_args=dict(attn_layer='se'))
  1721. return _create_resnet('seresnext101_32x4d', pretrained, **dict(model_args, **kwargs))
  1722. @register_model
  1723. def seresnext101_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
  1724. model_args = dict(
  1725. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
  1726. block_args=dict(attn_layer='se'))
  1727. return _create_resnet('seresnext101_32x8d', pretrained, **dict(model_args, **kwargs))
  1728. @register_model
  1729. def seresnext101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
  1730. model_args = dict(
  1731. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
  1732. stem_width=32, stem_type='deep', avg_down=True,
  1733. block_args=dict(attn_layer='se'))
  1734. return _create_resnet('seresnext101d_32x8d', pretrained, **dict(model_args, **kwargs))
  1735. @register_model
  1736. def seresnext101_64x4d(pretrained: bool = False, **kwargs) -> ResNet:
  1737. model_args = dict(
  1738. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=64, base_width=4,
  1739. block_args=dict(attn_layer='se'))
  1740. return _create_resnet('seresnext101_64x4d', pretrained, **dict(model_args, **kwargs))
  1741. @register_model
  1742. def senet154(pretrained: bool = False, **kwargs) -> ResNet:
  1743. model_args = dict(
  1744. block=Bottleneck, layers=(3, 8, 36, 3), cardinality=64, base_width=4, stem_type='deep',
  1745. down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'))
  1746. return _create_resnet('senet154', pretrained, **dict(model_args, **kwargs))
  1747. @register_model
  1748. def resnetblur18(pretrained: bool = False, **kwargs) -> ResNet:
  1749. """Constructs a ResNet-18 model with blur anti-aliasing
  1750. """
  1751. model_args = dict(block=BasicBlock, layers=(2, 2, 2, 2), aa_layer=BlurPool2d)
  1752. return _create_resnet('resnetblur18', pretrained, **dict(model_args, **kwargs))
  1753. @register_model
  1754. def resnetblur50(pretrained: bool = False, **kwargs) -> ResNet:
  1755. """Constructs a ResNet-50 model with blur anti-aliasing
  1756. """
  1757. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d)
  1758. return _create_resnet('resnetblur50', pretrained, **dict(model_args, **kwargs))
  1759. @register_model
  1760. def resnetblur50d(pretrained: bool = False, **kwargs) -> ResNet:
  1761. """Constructs a ResNet-50-D model with blur anti-aliasing
  1762. """
  1763. model_args = dict(
  1764. block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=BlurPool2d,
  1765. stem_width=32, stem_type='deep', avg_down=True)
  1766. return _create_resnet('resnetblur50d', pretrained, **dict(model_args, **kwargs))
  1767. @register_model
  1768. def resnetblur101d(pretrained: bool = False, **kwargs) -> ResNet:
  1769. """Constructs a ResNet-101-D model with blur anti-aliasing
  1770. """
  1771. model_args = dict(
  1772. block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=BlurPool2d,
  1773. stem_width=32, stem_type='deep', avg_down=True)
  1774. return _create_resnet('resnetblur101d', pretrained, **dict(model_args, **kwargs))
  1775. @register_model
  1776. def resnetaa34d(pretrained: bool = False, **kwargs) -> ResNet:
  1777. """Constructs a ResNet-34-D model w/ avgpool anti-aliasing
  1778. """
  1779. model_args = dict(
  1780. block=BasicBlock, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True)
  1781. return _create_resnet('resnetaa34d', pretrained, **dict(model_args, **kwargs))
  1782. @register_model
  1783. def resnetaa50(pretrained: bool = False, **kwargs) -> ResNet:
  1784. """Constructs a ResNet-50 model with avgpool anti-aliasing
  1785. """
  1786. model_args = dict(block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d)
  1787. return _create_resnet('resnetaa50', pretrained, **dict(model_args, **kwargs))
  1788. @register_model
  1789. def resnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
  1790. """Constructs a ResNet-50-D model with avgpool anti-aliasing
  1791. """
  1792. model_args = dict(
  1793. block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d,
  1794. stem_width=32, stem_type='deep', avg_down=True)
  1795. return _create_resnet('resnetaa50d', pretrained, **dict(model_args, **kwargs))
  1796. @register_model
  1797. def resnetaa101d(pretrained: bool = False, **kwargs) -> ResNet:
  1798. """Constructs a ResNet-101-D model with avgpool anti-aliasing
  1799. """
  1800. model_args = dict(
  1801. block=Bottleneck, layers=(3, 4, 23, 3), aa_layer=nn.AvgPool2d,
  1802. stem_width=32, stem_type='deep', avg_down=True)
  1803. return _create_resnet('resnetaa101d', pretrained, **dict(model_args, **kwargs))
  1804. @register_model
  1805. def seresnetaa50d(pretrained: bool = False, **kwargs) -> ResNet:
  1806. """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
  1807. """
  1808. model_args = dict(
  1809. block=Bottleneck, layers=(3, 4, 6, 3), aa_layer=nn.AvgPool2d,
  1810. stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
  1811. return _create_resnet('seresnetaa50d', pretrained, **dict(model_args, **kwargs))
  1812. @register_model
  1813. def seresnextaa101d_32x8d(pretrained: bool = False, **kwargs) -> ResNet:
  1814. """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
  1815. """
  1816. model_args = dict(
  1817. block=Bottleneck, layers=(3, 4, 23, 3), cardinality=32, base_width=8,
  1818. stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
  1819. block_args=dict(attn_layer='se'))
  1820. return _create_resnet('seresnextaa101d_32x8d', pretrained, **dict(model_args, **kwargs))
  1821. @register_model
  1822. def seresnextaa201d_32x8d(pretrained: bool = False, **kwargs):
  1823. """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
  1824. """
  1825. model_args = dict(
  1826. block=Bottleneck, layers=(3, 24, 36, 4), cardinality=32, base_width=8,
  1827. stem_width=64, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
  1828. block_args=dict(attn_layer='se'))
  1829. return _create_resnet('seresnextaa201d_32x8d', pretrained, **dict(model_args, **kwargs))
  1830. @register_model
  1831. def resnetrs50(pretrained: bool = False, **kwargs) -> ResNet:
  1832. """Constructs a ResNet-RS-50 model.
  1833. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1834. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1835. """
  1836. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1837. model_args = dict(
  1838. block=Bottleneck, layers=(3, 4, 6, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1839. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1840. return _create_resnet('resnetrs50', pretrained, **dict(model_args, **kwargs))
  1841. @register_model
  1842. def resnetrs101(pretrained: bool = False, **kwargs) -> ResNet:
  1843. """Constructs a ResNet-RS-101 model.
  1844. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1845. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1846. """
  1847. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1848. model_args = dict(
  1849. block=Bottleneck, layers=(3, 4, 23, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1850. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1851. return _create_resnet('resnetrs101', pretrained, **dict(model_args, **kwargs))
  1852. @register_model
  1853. def resnetrs152(pretrained: bool = False, **kwargs) -> ResNet:
  1854. """Constructs a ResNet-RS-152 model.
  1855. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1856. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1857. """
  1858. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1859. model_args = dict(
  1860. block=Bottleneck, layers=(3, 8, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1861. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1862. return _create_resnet('resnetrs152', pretrained, **dict(model_args, **kwargs))
  1863. @register_model
  1864. def resnetrs200(pretrained: bool = False, **kwargs) -> ResNet:
  1865. """Constructs a ResNet-RS-200 model.
  1866. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1867. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1868. """
  1869. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1870. model_args = dict(
  1871. block=Bottleneck, layers=(3, 24, 36, 3), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1872. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1873. return _create_resnet('resnetrs200', pretrained, **dict(model_args, **kwargs))
  1874. @register_model
  1875. def resnetrs270(pretrained: bool = False, **kwargs) -> ResNet:
  1876. """Constructs a ResNet-RS-270 model.
  1877. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1878. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1879. """
  1880. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1881. model_args = dict(
  1882. block=Bottleneck, layers=(4, 29, 53, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1883. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1884. return _create_resnet('resnetrs270', pretrained, **dict(model_args, **kwargs))
  1885. @register_model
  1886. def resnetrs350(pretrained: bool = False, **kwargs) -> ResNet:
  1887. """Constructs a ResNet-RS-350 model.
  1888. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1889. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1890. """
  1891. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1892. model_args = dict(
  1893. block=Bottleneck, layers=(4, 36, 72, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1894. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1895. return _create_resnet('resnetrs350', pretrained, **dict(model_args, **kwargs))
  1896. @register_model
  1897. def resnetrs420(pretrained: bool = False, **kwargs) -> ResNet:
  1898. """Constructs a ResNet-RS-420 model
  1899. Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
  1900. Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
  1901. """
  1902. attn_layer = partial(get_attn('se'), rd_ratio=0.25)
  1903. model_args = dict(
  1904. block=Bottleneck, layers=(4, 44, 87, 4), stem_width=32, stem_type='deep', replace_stem_pool=True,
  1905. avg_down=True, block_args=dict(attn_layer=attn_layer))
  1906. return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))
  1907. @register_model
  1908. def test_resnet(pretrained: bool = False, **kwargs) -> ResNet:
  1909. """Constructs a tiny ResNet test model.
  1910. """
  1911. model_args = dict(
  1912. block=[BasicBlock, BasicBlock, Bottleneck, BasicBlock], layers=(1, 1, 1, 1),
  1913. stem_width=16, stem_type='deep', avg_down=True, channels=(32, 48, 48, 96))
  1914. return _create_resnet('test_resnet', pretrained, **dict(model_args, **kwargs))
# Register legacy model names for this module. Each key is an old name and each
# value is the replacement, written as 'architecture.pretrained_tag' (the last
# entry has no tag). How the mapping is consumed is defined by
# `register_model_deprecations`, which is not visible here.
register_model_deprecations(__name__, {
    # torchvision-prefixed names
    'tv_resnet34': 'resnet34.tv_in1k',
    'tv_resnet50': 'resnet50.tv_in1k',
    'tv_resnet101': 'resnet101.tv_in1k',
    'tv_resnet152': 'resnet152.tv_in1k',
    'tv_resnext50_32x4d' : 'resnext50_32x4d.tv_in1k',
    # 'ig_' prefixed names
    # NOTE(review): all four ig_* widths (8d/16d/32d/48d) map to the 32x8d target,
    # whereas the ssl_/swsl_ 32x16d names below map to resnext101_32x16d.
    # Confirm this redirect is intentional and not a copy-paste slip.
    'ig_resnext101_32x8d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
    'ig_resnext101_32x16d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
    'ig_resnext101_32x32d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
    'ig_resnext101_32x48d': 'resnext101_32x8d.fb_wsl_ig1b_ft_in1k',
    # 'ssl_' prefixed names
    'ssl_resnet18': 'resnet18.fb_ssl_yfcc100m_ft_in1k',
    'ssl_resnet50': 'resnet50.fb_ssl_yfcc100m_ft_in1k',
    'ssl_resnext50_32x4d': 'resnext50_32x4d.fb_ssl_yfcc100m_ft_in1k',
    'ssl_resnext101_32x4d': 'resnext101_32x4d.fb_ssl_yfcc100m_ft_in1k',
    'ssl_resnext101_32x8d': 'resnext101_32x8d.fb_ssl_yfcc100m_ft_in1k',
    'ssl_resnext101_32x16d': 'resnext101_32x16d.fb_ssl_yfcc100m_ft_in1k',
    # 'swsl_' prefixed names
    'swsl_resnet18': 'resnet18.fb_swsl_ig1b_ft_in1k',
    'swsl_resnet50': 'resnet50.fb_swsl_ig1b_ft_in1k',
    'swsl_resnext50_32x4d': 'resnext50_32x4d.fb_swsl_ig1b_ft_in1k',
    'swsl_resnext101_32x4d': 'resnext101_32x4d.fb_swsl_ig1b_ft_in1k',
    'swsl_resnext101_32x8d': 'resnext101_32x8d.fb_swsl_ig1b_ft_in1k',
    'swsl_resnext101_32x16d': 'resnext101_32x16d.fb_swsl_ig1b_ft_in1k',
    # 'gluon_' prefixed names; the v1b/v1c/v1d/v1s suffixes map to the
    # plain / 'c' / 'd' / 's' architecture names respectively
    'gluon_resnet18_v1b': 'resnet18.gluon_in1k',
    'gluon_resnet34_v1b': 'resnet34.gluon_in1k',
    'gluon_resnet50_v1b': 'resnet50.gluon_in1k',
    'gluon_resnet101_v1b': 'resnet101.gluon_in1k',
    'gluon_resnet152_v1b': 'resnet152.gluon_in1k',
    'gluon_resnet50_v1c': 'resnet50c.gluon_in1k',
    'gluon_resnet101_v1c': 'resnet101c.gluon_in1k',
    'gluon_resnet152_v1c': 'resnet152c.gluon_in1k',
    'gluon_resnet50_v1d': 'resnet50d.gluon_in1k',
    'gluon_resnet101_v1d': 'resnet101d.gluon_in1k',
    'gluon_resnet152_v1d': 'resnet152d.gluon_in1k',
    'gluon_resnet50_v1s': 'resnet50s.gluon_in1k',
    'gluon_resnet101_v1s': 'resnet101s.gluon_in1k',
    'gluon_resnet152_v1s': 'resnet152s.gluon_in1k',
    'gluon_resnext50_32x4d': 'resnext50_32x4d.gluon_in1k',
    'gluon_resnext101_32x4d': 'resnext101_32x4d.gluon_in1k',
    'gluon_resnext101_64x4d': 'resnext101_64x4d.gluon_in1k',
    'gluon_seresnext50_32x4d': 'seresnext50_32x4d.gluon_in1k',
    'gluon_seresnext101_32x4d': 'seresnext101_32x4d.gluon_in1k',
    'gluon_seresnext101_64x4d': 'seresnext101_64x4d.gluon_in1k',
    'gluon_senet154': 'senet154.gluon_in1k',
    # plain rename ('tn' -> 't'), no pretrained tag
    'seresnext26tn_32x4d': 'seresnext26t_32x4d',
})