vision_transformer.py 189 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461
  1. """ Vision Transformer (ViT) in PyTorch
  2. A PyTorch implementation of Vision Transformers as described in:
  3. 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
  4. - https://arxiv.org/abs/2010.11929
  5. `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
  6. - https://arxiv.org/abs/2106.10270
  7. `FlexiViT: One Model for All Patch Sizes`
  8. - https://arxiv.org/abs/2212.08013
  9. The official jax code is released and available at
  10. * https://github.com/google-research/vision_transformer
  11. * https://github.com/google-research/big_vision
  12. Acknowledgments:
  13. * The paper authors for releasing code and weights, thanks!
  14. * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch
  15. * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
  16. * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
  17. Hacked together by / Copyright 2020, Ross Wightman
  18. """
  19. import copy
  20. import logging
  21. import math
  22. import os
  23. from collections import OrderedDict
  24. from functools import partial
  25. from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List
  26. try:
  27. from typing import Literal
  28. except ImportError:
  29. from typing_extensions import Literal
  30. import torch
  31. import torch.nn as nn
  32. import torch.nn.functional as F
  33. from torch.jit import Final
  34. from timm.data import (
  35. IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD,
  36. IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,
  37. OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
  38. )
  39. from timm.layers import (
  40. Attention,
  41. AttentionPoolLatent,
  42. PatchEmbed,
  43. Mlp,
  44. SwiGLUPacked,
  45. SwiGLU,
  46. LayerNorm,
  47. RmsNorm,
  48. DropPath,
  49. calculate_drop_path_rates,
  50. PatchDropout,
  51. trunc_normal_,
  52. lecun_normal_,
  53. resample_patch_embed,
  54. resample_abs_pos_embed,
  55. use_fused_attn,
  56. get_act_layer,
  57. get_norm_layer,
  58. maybe_add_mask,
  59. LayerType,
  60. LayerScale,
  61. )
  62. from ._builder import build_model_with_cfg
  63. from ._features import feature_take_indices
  64. from ._manipulate import named_apply, checkpoint, checkpoint_seq, adapt_input_conv
  65. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
  66. __all__ = ['VisionTransformer'] # model_registry will add each entrypoint fn to this
  67. _logger = logging.getLogger(__name__)
  68. class Block(nn.Module):
  69. """Transformer block with pre-normalization."""
  70. def __init__(
  71. self,
  72. dim: int,
  73. num_heads: int,
  74. mlp_ratio: float = 4.,
  75. qkv_bias: bool = False,
  76. qk_norm: bool = False,
  77. scale_attn_norm: bool = False,
  78. scale_mlp_norm: bool = False,
  79. proj_bias: bool = True,
  80. proj_drop: float = 0.,
  81. attn_drop: float = 0.,
  82. init_values: Optional[float] = None,
  83. drop_path: float = 0.,
  84. act_layer: Type[nn.Module] = nn.GELU,
  85. norm_layer: Type[nn.Module] = LayerNorm,
  86. mlp_layer: Type[nn.Module] = Mlp,
  87. device=None,
  88. dtype=None,
  89. ) -> None:
  90. """Initialize Block.
  91. Args:
  92. dim: Number of input channels.
  93. num_heads: Number of attention heads.
  94. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  95. qkv_bias: If True, add a learnable bias to query, key, value.
  96. qk_norm: If True, apply normalization to query and key.
  97. proj_bias: If True, add bias to output projection.
  98. proj_drop: Projection dropout rate.
  99. attn_drop: Attention dropout rate.
  100. init_values: Initial values for layer scale.
  101. drop_path: Stochastic depth rate.
  102. act_layer: Activation layer.
  103. norm_layer: Normalization layer.
  104. mlp_layer: MLP layer.
  105. """
  106. super().__init__()
  107. dd = {'device': device, 'dtype': dtype}
  108. self.norm1 = norm_layer(dim, **dd)
  109. self.attn = Attention(
  110. dim,
  111. num_heads=num_heads,
  112. qkv_bias=qkv_bias,
  113. qk_norm=qk_norm,
  114. scale_norm=scale_attn_norm,
  115. proj_bias=proj_bias,
  116. attn_drop=attn_drop,
  117. proj_drop=proj_drop,
  118. norm_layer=norm_layer,
  119. **dd
  120. )
  121. self.ls1 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()
  122. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  123. self.norm2 = norm_layer(dim, **dd)
  124. self.mlp = mlp_layer(
  125. in_features=dim,
  126. hidden_features=int(dim * mlp_ratio),
  127. act_layer=act_layer,
  128. norm_layer=norm_layer if scale_mlp_norm else None,
  129. bias=proj_bias,
  130. drop=proj_drop,
  131. **dd,
  132. )
  133. self.ls2 = LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()
  134. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  135. def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  136. x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
  137. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  138. return x
  139. class ResPostBlock(nn.Module):
  140. def __init__(
  141. self,
  142. dim: int,
  143. num_heads: int,
  144. mlp_ratio: float = 4.,
  145. qkv_bias: bool = False,
  146. qk_norm: bool = False,
  147. scale_attn_norm: bool = False,
  148. scale_mlp_norm: bool = False,
  149. proj_bias: bool = True,
  150. proj_drop: float = 0.,
  151. attn_drop: float = 0.,
  152. init_values: Optional[float] = None,
  153. drop_path: float = 0.,
  154. act_layer: Type[nn.Module] = nn.GELU,
  155. norm_layer: Type[nn.Module] = LayerNorm,
  156. mlp_layer: Type[nn.Module] = Mlp,
  157. device = None,
  158. dtype = None,
  159. ) -> None:
  160. super().__init__()
  161. dd = {'device': device, 'dtype': dtype}
  162. self.init_values = init_values
  163. self.attn = Attention(
  164. dim,
  165. num_heads=num_heads,
  166. qkv_bias=qkv_bias,
  167. qk_norm=qk_norm,
  168. scale_norm=scale_attn_norm,
  169. proj_bias=proj_bias,
  170. attn_drop=attn_drop,
  171. proj_drop=proj_drop,
  172. norm_layer=norm_layer,
  173. **dd,
  174. )
  175. self.norm1 = norm_layer(dim, **dd)
  176. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  177. self.mlp = mlp_layer(
  178. in_features=dim,
  179. hidden_features=int(dim * mlp_ratio),
  180. act_layer=act_layer,
  181. norm_layer=norm_layer if scale_mlp_norm else None,
  182. bias=proj_bias,
  183. drop=proj_drop,
  184. **dd,
  185. )
  186. self.norm2 = norm_layer(dim, **dd)
  187. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  188. self.init_weights()
  189. def init_weights(self) -> None:
  190. # NOTE this init overrides that base model init with specific changes for the block type
  191. if self.init_values is not None:
  192. nn.init.constant_(self.norm1.weight, self.init_values)
  193. nn.init.constant_(self.norm2.weight, self.init_values)
  194. def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  195. x = x + self.drop_path1(self.norm1(self.attn(x, attn_mask=attn_mask)))
  196. x = x + self.drop_path2(self.norm2(self.mlp(x)))
  197. return x
  198. class ParallelScalingBlock(nn.Module):
  199. """ Parallel ViT block (MLP & Attention in parallel)
  200. Based on:
  201. 'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
  202. """
  203. fused_attn: Final[bool]
  204. def __init__(
  205. self,
  206. dim: int,
  207. num_heads: int,
  208. mlp_ratio: float = 4.,
  209. qkv_bias: bool = False,
  210. qk_norm: bool = False,
  211. scale_attn_norm: bool = False,
  212. scale_mlp_norm: bool = False,
  213. proj_bias: bool = True,
  214. proj_drop: float = 0.,
  215. attn_drop: float = 0.,
  216. init_values: Optional[float] = None,
  217. drop_path: float = 0.,
  218. act_layer: Type[nn.Module] = nn.GELU,
  219. norm_layer: Type[nn.Module] = LayerNorm,
  220. mlp_layer: Optional[Type[nn.Module]] = None,
  221. device = None,
  222. dtype = None,
  223. ) -> None:
  224. super().__init__()
  225. dd = {'device': device, 'dtype': dtype}
  226. assert dim % num_heads == 0, 'dim should be divisible by num_heads'
  227. assert not scale_attn_norm and not scale_mlp_norm, 'Scale norms not supported'
  228. self.num_heads = num_heads
  229. self.head_dim = dim // num_heads
  230. self.scale = self.head_dim ** -0.5
  231. self.fused_attn = use_fused_attn()
  232. mlp_hidden_dim = int(mlp_ratio * dim)
  233. in_proj_out_dim = mlp_hidden_dim + 3 * dim
  234. self.in_norm = norm_layer(dim, **dd)
  235. self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias, **dd)
  236. self.in_split = [mlp_hidden_dim] + [dim] * 3
  237. if qkv_bias:
  238. self.register_buffer('qkv_bias', None)
  239. self.register_parameter('mlp_bias', None)
  240. else:
  241. self.register_buffer('qkv_bias', torch.zeros(3 * dim, **dd), persistent=False)
  242. self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim, **dd))
  243. self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  244. self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  245. self.attn_drop = nn.Dropout(attn_drop)
  246. self.attn_out_proj = nn.Linear(dim, dim, bias=proj_bias, **dd)
  247. self.mlp_drop = nn.Dropout(proj_drop)
  248. self.mlp_act = act_layer()
  249. self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim, bias=proj_bias, **dd)
  250. self.ls = LayerScale(dim, init_values=init_values, **dd) if init_values is not None else nn.Identity()
  251. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  252. def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  253. B, N, C = x.shape
  254. # Combined MLP fc1 & qkv projections
  255. y = self.in_norm(x)
  256. if self.mlp_bias is not None:
  257. # Concat constant zero-bias for qkv w/ trainable mlp_bias.
  258. # Appears faster than adding to x_mlp separately
  259. y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
  260. else:
  261. y = self.in_proj(y)
  262. x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
  263. # Dot product attention w/ qk norm
  264. q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
  265. k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
  266. v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
  267. if self.fused_attn:
  268. x_attn = F.scaled_dot_product_attention(
  269. q, k, v,
  270. attn_mask=attn_mask,
  271. dropout_p=self.attn_drop.p if self.training else 0.,
  272. )
  273. else:
  274. q = q * self.scale
  275. attn = q @ k.transpose(-2, -1)
  276. attn = maybe_add_mask(attn, attn_mask)
  277. attn = attn.softmax(dim=-1)
  278. attn = self.attn_drop(attn)
  279. x_attn = attn @ v
  280. x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
  281. x_attn = self.attn_out_proj(x_attn)
  282. # MLP activation, dropout, fc2
  283. x_mlp = self.mlp_act(x_mlp)
  284. x_mlp = self.mlp_drop(x_mlp)
  285. x_mlp = self.mlp_out_proj(x_mlp)
  286. # Add residual w/ drop path & layer scale applied
  287. y = self.drop_path(self.ls(x_attn + x_mlp))
  288. x = x + y
  289. return x
  290. class ParallelThingsBlock(nn.Module):
  291. """ Parallel ViT block (N parallel attention followed by N parallel MLP)
  292. Based on:
  293. `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
  294. """
  295. def __init__(
  296. self,
  297. dim: int,
  298. num_heads: int,
  299. num_parallel: int = 2,
  300. mlp_ratio: float = 4.,
  301. qkv_bias: bool = False,
  302. qk_norm: bool = False,
  303. scale_attn_norm: bool = False,
  304. scale_mlp_norm: bool = False,
  305. proj_bias: bool = True,
  306. init_values: Optional[float] = None,
  307. proj_drop: float = 0.,
  308. attn_drop: float = 0.,
  309. drop_path: float = 0.,
  310. act_layer: Type[nn.Module] = nn.GELU,
  311. norm_layer: Type[nn.Module] = LayerNorm,
  312. mlp_layer: Type[nn.Module] = Mlp,
  313. device = None,
  314. dtype = None
  315. ) -> None:
  316. dd = {'device': device, 'dtype': dtype}
  317. super().__init__()
  318. self.num_parallel = num_parallel
  319. self.attns = nn.ModuleList()
  320. self.ffns = nn.ModuleList()
  321. for _ in range(num_parallel):
  322. self.attns.append(nn.Sequential(OrderedDict([
  323. ('norm', norm_layer(dim, **dd)),
  324. ('attn', Attention(
  325. dim,
  326. num_heads=num_heads,
  327. qkv_bias=qkv_bias,
  328. qk_norm=qk_norm,
  329. scale_norm=scale_attn_norm,
  330. proj_bias=proj_bias,
  331. attn_drop=attn_drop,
  332. proj_drop=proj_drop,
  333. norm_layer=norm_layer,
  334. **dd,
  335. )),
  336. ('ls', LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()),
  337. ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
  338. ])))
  339. self.ffns.append(nn.Sequential(OrderedDict([
  340. ('norm', norm_layer(dim, **dd)),
  341. ('mlp', mlp_layer(
  342. dim,
  343. hidden_features=int(dim * mlp_ratio),
  344. act_layer=act_layer,
  345. norm_layer=norm_layer if scale_mlp_norm else None,
  346. bias=proj_bias,
  347. drop=proj_drop,
  348. **dd,
  349. )),
  350. ('ls', LayerScale(dim, init_values=init_values, **dd) if init_values else nn.Identity()),
  351. ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
  352. ])))
  353. def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  354. if attn_mask is not None:
  355. attn_out = []
  356. for attn in self.attns:
  357. x_attn = attn.norm(x)
  358. x_attn = attn.attn(x_attn, attn_mask=attn_mask)
  359. x_attn = attn.ls(x_attn)
  360. x_attn = attn.drop_path(x_attn)
  361. attn_out.append(x_attn)
  362. x = x + torch.stack(attn_out).sum(dim=0)
  363. else:
  364. x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
  365. x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
  366. return x
  367. def global_pool_nlc(
  368. x: torch.Tensor,
  369. pool_type: str = 'token',
  370. num_prefix_tokens: int = 1,
  371. reduce_include_prefix: bool = False,
  372. ):
  373. if not pool_type:
  374. return x
  375. if pool_type == 'token':
  376. x = x[:, 0] # class token
  377. else:
  378. x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
  379. if pool_type == 'avg':
  380. x = x.mean(dim=1)
  381. elif pool_type == 'avgmax':
  382. x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
  383. elif pool_type == 'max':
  384. x = x.amax(dim=1)
  385. else:
  386. assert not pool_type, f'Unknown pool type {pool_type}'
  387. return x
  388. class VisionTransformer(nn.Module):
  389. """ Vision Transformer
  390. A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
  391. - https://arxiv.org/abs/2010.11929
  392. """
  393. dynamic_img_size: Final[bool]
  394. def __init__(
  395. self,
  396. img_size: Union[int, Tuple[int, int]] = 224,
  397. patch_size: Union[int, Tuple[int, int]] = 16,
  398. in_chans: int = 3,
  399. num_classes: int = 1000,
  400. global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
  401. embed_dim: int = 768,
  402. depth: int = 12,
  403. num_heads: int = 12,
  404. mlp_ratio: float = 4.,
  405. qkv_bias: bool = True,
  406. qk_norm: bool = False,
  407. scale_attn_norm: bool = False,
  408. scale_mlp_norm: bool = False,
  409. proj_bias: bool = True,
  410. init_values: Optional[float] = None,
  411. class_token: bool = True,
  412. pos_embed: str = 'learn',
  413. no_embed_class: bool = False,
  414. reg_tokens: int = 0,
  415. pre_norm: bool = False,
  416. final_norm: bool = True,
  417. fc_norm: Optional[bool] = None,
  418. pool_include_prefix: bool = False,
  419. dynamic_img_size: bool = False,
  420. dynamic_img_pad: bool = False,
  421. drop_rate: float = 0.,
  422. pos_drop_rate: float = 0.,
  423. patch_drop_rate: float = 0.,
  424. proj_drop_rate: float = 0.,
  425. attn_drop_rate: float = 0.,
  426. drop_path_rate: float = 0.,
  427. weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
  428. fix_init: bool = False,
  429. embed_layer: Callable = PatchEmbed,
  430. embed_norm_layer: Optional[LayerType] = None,
  431. norm_layer: Optional[LayerType] = None,
  432. act_layer: Optional[LayerType] = None,
  433. block_fn: Type[nn.Module] = Block,
  434. mlp_layer: Type[nn.Module] = Mlp,
  435. device=None,
  436. dtype=None,
  437. ) -> None:
  438. """
  439. Args:
  440. img_size: Input image size.
  441. patch_size: Patch size.
  442. in_chans: Number of image input channels.
  443. num_classes: Number of classes for classification head.
  444. global_pool: Type of global pooling for final sequence (default: 'token').
  445. embed_dim: Transformer embedding dimension.
  446. depth: Depth of transformer.
  447. num_heads: Number of attention heads.
  448. mlp_ratio: Ratio of mlp hidden dim to embedding dim.
  449. qkv_bias: Enable bias for qkv projections if True.
  450. init_values: Layer-scale init values (layer-scale enabled if not None).
  451. class_token: Use class token.
  452. no_embed_class: Don't include position embeddings for class (or reg) tokens.
  453. reg_tokens: Number of register tokens.
  454. pre_norm: Enable norm after embeddings, before transformer blocks (standard in CLIP ViT).
  455. final_norm: Enable norm after transformer blocks, before head (standard in most ViT).
  456. fc_norm: Move final norm after pool (instead of before), if None, enabled when global_pool == 'avg'.
  457. drop_rate: Head dropout rate.
  458. pos_drop_rate: Position embedding dropout rate.
  459. attn_drop_rate: Attention dropout rate.
  460. drop_path_rate: Stochastic depth rate.
  461. weight_init: Weight initialization scheme.
  462. fix_init: Apply weight initialization fix (scaling w/ layer index).
  463. embed_layer: Patch embedding layer.
  464. embed_norm_layer: Normalization layer to use / override in patch embed module.
  465. norm_layer: Normalization layer.
  466. act_layer: MLP activation layer.
  467. block_fn: Transformer block layer.
  468. """
  469. super().__init__()
  470. dd = {'device': device, 'dtype': dtype}
  471. assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
  472. assert class_token or global_pool != 'token'
  473. assert pos_embed in ('', 'none', 'learn')
  474. use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
  475. norm_layer = get_norm_layer(norm_layer) or LayerNorm
  476. embed_norm_layer = get_norm_layer(embed_norm_layer)
  477. act_layer = get_act_layer(act_layer) or nn.GELU
  478. self.num_classes = num_classes
  479. self.global_pool = global_pool
  480. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  481. self.num_prefix_tokens = 1 if class_token else 0
  482. self.num_prefix_tokens += reg_tokens
  483. self.num_reg_tokens = reg_tokens
  484. self.has_class_token = class_token
  485. self.no_embed_class = no_embed_class
  486. self.pool_include_prefix = pool_include_prefix
  487. self.dynamic_img_size = dynamic_img_size
  488. self.grad_checkpointing = False
  489. embed_args = {}
  490. if dynamic_img_size:
  491. # flatten deferred until after pos embed
  492. embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
  493. if embed_norm_layer is not None:
  494. embed_args['norm_layer'] = embed_norm_layer
  495. self.patch_embed = embed_layer(
  496. img_size=img_size,
  497. patch_size=patch_size,
  498. in_chans=in_chans,
  499. embed_dim=embed_dim,
  500. bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
  501. dynamic_img_pad=dynamic_img_pad,
  502. **embed_args,
  503. **dd,
  504. )
  505. num_patches = self.patch_embed.num_patches
  506. reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
  507. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) if class_token else None
  508. self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim, **dd)) if reg_tokens else None
  509. embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
  510. if not pos_embed or pos_embed == 'none':
  511. self.pos_embed = None
  512. else:
  513. self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim, **dd) * .02)
  514. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  515. if patch_drop_rate > 0:
  516. self.patch_drop = PatchDropout(
  517. patch_drop_rate,
  518. num_prefix_tokens=self.num_prefix_tokens,
  519. )
  520. else:
  521. self.patch_drop = nn.Identity()
  522. self.norm_pre = norm_layer(embed_dim, **dd) if pre_norm else nn.Identity()
  523. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  524. self.blocks = nn.Sequential(*[
  525. block_fn(
  526. dim=embed_dim,
  527. num_heads=num_heads,
  528. mlp_ratio=mlp_ratio,
  529. qkv_bias=qkv_bias,
  530. qk_norm=qk_norm,
  531. scale_attn_norm=scale_attn_norm,
  532. scale_mlp_norm=scale_mlp_norm,
  533. proj_bias=proj_bias,
  534. init_values=init_values,
  535. proj_drop=proj_drop_rate,
  536. attn_drop=attn_drop_rate,
  537. drop_path=dpr[i],
  538. norm_layer=norm_layer,
  539. act_layer=act_layer,
  540. mlp_layer=mlp_layer,
  541. **dd,
  542. )
  543. for i in range(depth)])
  544. self.feature_info = [
  545. dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
  546. self.norm = norm_layer(embed_dim, **dd) if final_norm and not use_fc_norm else nn.Identity()
  547. # Classifier Head
  548. if global_pool == 'map':
  549. self.attn_pool = AttentionPoolLatent(
  550. self.embed_dim,
  551. num_heads=num_heads,
  552. mlp_ratio=mlp_ratio,
  553. norm_layer=norm_layer,
  554. act_layer=act_layer,
  555. **dd,
  556. )
  557. else:
  558. self.attn_pool = None
  559. self.fc_norm = norm_layer(embed_dim, **dd) if final_norm and use_fc_norm else nn.Identity()
  560. self.head_drop = nn.Dropout(drop_rate)
  561. self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  562. if weight_init != 'skip':
  563. self.init_weights(weight_init)
  564. if fix_init:
  565. self.fix_init_weight()
  566. def fix_init_weight(self) -> None:
  567. """Apply weight initialization fix (scaling w/ layer index)."""
  568. def rescale(param, _layer_id):
  569. param.div_(math.sqrt(2.0 * _layer_id))
  570. for layer_id, layer in enumerate(self.blocks):
  571. rescale(layer.attn.proj.weight.data, layer_id + 1)
  572. rescale(layer.mlp.fc2.weight.data, layer_id + 1)
  573. def init_weights(self, mode: str = '') -> None:
  574. """Initialize model weights.
  575. Args:
  576. mode: Weight initialization mode ('jax', 'jax_nlhb', 'moco', or '').
  577. """
  578. assert mode in ('jax', 'jax_nlhb', 'moco', '')
  579. head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
  580. if self.pos_embed is not None:
  581. trunc_normal_(self.pos_embed, std=.02)
  582. if self.cls_token is not None:
  583. nn.init.normal_(self.cls_token, std=1e-6)
  584. if self.reg_token is not None:
  585. nn.init.normal_(self.reg_token, std=1e-6)
  586. named_apply(get_init_weights_vit(mode, head_bias), self)
  587. def _init_weights(self, m: nn.Module) -> None:
  588. """Initialize weights for a single module (compatibility method)."""
  589. # this fn left here for compat with downstream users
  590. init_weights_vit_timm(m)
  591. @torch.jit.ignore()
  592. def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
  593. """Load pretrained weights.
  594. Args:
  595. checkpoint_path: Path to checkpoint.
  596. prefix: Prefix for state dict keys.
  597. """
  598. _load_weights(self, checkpoint_path, prefix)
  599. @torch.jit.ignore
  600. def no_weight_decay(self) -> Set[str]:
  601. """Set of parameters that should not use weight decay."""
  602. return {'pos_embed', 'cls_token', 'dist_token'}
  603. @torch.jit.ignore
  604. def group_matcher(self, coarse: bool = False) -> Dict[str, Union[str, List]]:
  605. """Create regex patterns for parameter grouping.
  606. Args:
  607. coarse: Use coarse grouping.
  608. Returns:
  609. Dictionary mapping group names to regex patterns.
  610. """
  611. return dict(
  612. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  613. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  614. )
  615. @torch.jit.ignore
  616. def set_grad_checkpointing(self, enable: bool = True) -> None:
  617. """Enable or disable gradient checkpointing.
  618. Args:
  619. enable: Whether to enable gradient checkpointing.
  620. """
  621. self.grad_checkpointing = enable
  622. if hasattr(self.patch_embed, 'set_grad_checkpointing'):
  623. self.patch_embed.set_grad_checkpointing(enable)
  624. @torch.jit.ignore
  625. def get_classifier(self) -> nn.Module:
  626. """Get the classifier head."""
  627. return self.head
  628. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  629. """Reset the classifier head.
  630. Args:
  631. num_classes: Number of classes for new classifier.
  632. global_pool: Global pooling type.
  633. """
  634. self.num_classes = num_classes
  635. if global_pool is not None:
  636. assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
  637. if global_pool == 'map' and self.attn_pool is None:
  638. assert False, "Cannot currently add attention pooling in reset_classifier()."
  639. elif global_pool != 'map' and self.attn_pool is not None:
  640. self.attn_pool = None # remove attention pooling
  641. self.global_pool = global_pool
  642. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  643. def set_input_size(
  644. self,
  645. img_size: Optional[Tuple[int, int]] = None,
  646. patch_size: Optional[Tuple[int, int]] = None,
  647. ) -> None:
  648. """Update the input image resolution and patch size.
  649. Args:
  650. img_size: New input resolution, if None current resolution is used.
  651. patch_size: New patch size, if None existing patch size is used.
  652. """
  653. prev_grid_size = self.patch_embed.grid_size
  654. self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)
  655. if self.pos_embed is not None:
  656. num_prefix_tokens = 0 if self.no_embed_class else self.num_prefix_tokens
  657. num_new_tokens = self.patch_embed.num_patches + num_prefix_tokens
  658. if num_new_tokens != self.pos_embed.shape[1]:
  659. self.pos_embed = nn.Parameter(resample_abs_pos_embed(
  660. self.pos_embed,
  661. new_size=self.patch_embed.grid_size,
  662. old_size=prev_grid_size,
  663. num_prefix_tokens=num_prefix_tokens,
  664. verbose=True,
  665. ))
  666. def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
  667. """Apply positional embedding to input."""
  668. if self.pos_embed is None:
  669. return x.view(x.shape[0], -1, x.shape[-1])
  670. if self.dynamic_img_size:
  671. B, H, W, C = x.shape
  672. prev_grid_size = self.patch_embed.grid_size
  673. pos_embed = resample_abs_pos_embed(
  674. self.pos_embed,
  675. new_size=(H, W),
  676. old_size=prev_grid_size,
  677. num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
  678. )
  679. x = x.view(B, -1, C)
  680. else:
  681. pos_embed = self.pos_embed
  682. to_cat = []
  683. if self.cls_token is not None:
  684. to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
  685. if self.reg_token is not None:
  686. to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
  687. if self.no_embed_class:
  688. # deit-3, updated JAX (big vision)
  689. # position embedding does not overlap with class token, add then concat
  690. x = x + pos_embed
  691. if to_cat:
  692. x = torch.cat(to_cat + [x], dim=1)
  693. else:
  694. # original timm, JAX, and deit vit impl
  695. # pos_embed has entry for class token, concat then add
  696. if to_cat:
  697. x = torch.cat(to_cat + [x], dim=1)
  698. x = x + pos_embed
  699. return self.pos_drop(x)
  700. def forward_intermediates(
  701. self,
  702. x: torch.Tensor,
  703. indices: Optional[Union[int, List[int]]] = None,
  704. return_prefix_tokens: bool = False,
  705. norm: bool = False,
  706. stop_early: bool = False,
  707. output_fmt: str = 'NCHW',
  708. intermediates_only: bool = False,
  709. output_dict: bool = False,
  710. attn_mask: Optional[torch.Tensor] = None,
  711. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
  712. """ Forward features that returns intermediates.
  713. Args:
  714. x: Input image tensor
  715. indices: Take last n blocks if int, all if None, select matching indices if sequence
  716. return_prefix_tokens: Return both prefix and spatial intermediate tokens
  717. norm: Apply norm layer to all intermediates
  718. stop_early: Stop iterating over blocks when last desired intermediate hit
  719. output_fmt: Shape of intermediate feature outputs
  720. intermediates_only: Only return intermediate features
  721. output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
  722. attn_mask: Optional attention mask for masked attention (e.g., for NaFlex)
  723. Returns:
  724. A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
  725. 'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
  726. """
  727. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  728. reshape = output_fmt == 'NCHW'
  729. intermediates = []
  730. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  731. # forward pass
  732. B, _, height, width = x.shape
  733. x = self.patch_embed(x)
  734. x = self._pos_embed(x)
  735. x = self.patch_drop(x)
  736. x = self.norm_pre(x)
  737. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  738. blocks = self.blocks
  739. else:
  740. blocks = self.blocks[:max_index + 1]
  741. for i, blk in enumerate(blocks):
  742. if attn_mask is not None:
  743. x = blk(x, attn_mask=attn_mask)
  744. elif self.grad_checkpointing and not torch.jit.is_scripting():
  745. x = checkpoint(blk, x)
  746. else:
  747. x = blk(x)
  748. if i in take_indices:
  749. # normalize intermediates with final norm layer if enabled
  750. intermediates.append(self.norm(x) if norm else x)
  751. # process intermediates
  752. if self.num_prefix_tokens:
  753. # split prefix (e.g. class, distill) and spatial feature tokens
  754. prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
  755. intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
  756. else:
  757. prefix_tokens = None
  758. if reshape:
  759. # reshape to BCHW output format
  760. H, W = self.patch_embed.dynamic_feat_size((height, width))
  761. intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  762. # For dictionary output, handle prefix tokens separately
  763. if output_dict:
  764. result_dict = {}
  765. # Intermediates are always included
  766. result_dict['image_intermediates'] = intermediates
  767. if prefix_tokens is not None and return_prefix_tokens:
  768. result_dict['image_intermediates_prefix'] = prefix_tokens
  769. # Only include features if not intermediates_only
  770. if not intermediates_only:
  771. x_final = self.norm(x)
  772. result_dict['image_features'] = x_final
  773. return result_dict
  774. # For non-dictionary output, maintain the original behavior
  775. if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
  776. # return_prefix not support in torchscript due to poor type handling
  777. intermediates = list(zip(intermediates, prefix_tokens))
  778. if intermediates_only:
  779. return intermediates
  780. x = self.norm(x)
  781. return x, intermediates
  782. def prune_intermediate_layers(
  783. self,
  784. indices: Union[int, List[int]] = 1,
  785. prune_norm: bool = False,
  786. prune_head: bool = True,
  787. ) -> List[int]:
  788. """Prune layers not required for specified intermediates.
  789. Args:
  790. indices: Indices of intermediate layers to keep.
  791. prune_norm: Whether to prune normalization layer.
  792. prune_head: Whether to prune the classifier head.
  793. Returns:
  794. List of indices that were kept.
  795. """
  796. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  797. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  798. if prune_norm:
  799. self.norm = nn.Identity()
  800. if prune_head:
  801. self.fc_norm = nn.Identity()
  802. self.reset_classifier(0, '')
  803. return take_indices
  804. def get_intermediate_layers(
  805. self,
  806. x: torch.Tensor,
  807. n: Union[int, List[int], Tuple[int]] = 1,
  808. reshape: bool = False,
  809. return_prefix_tokens: bool = False,
  810. norm: bool = False,
  811. attn_mask: Optional[torch.Tensor] = None,
  812. ) -> List[torch.Tensor]:
  813. """Get intermediate layer outputs (DINO interface compatibility).
  814. NOTE: This API is for backwards compat, favour using forward_intermediates() directly.
  815. Args:
  816. x: Input tensor.
  817. n: Number or indices of layers.
  818. reshape: Reshape to NCHW format.
  819. return_prefix_tokens: Return prefix tokens.
  820. norm: Apply normalization.
  821. Returns:
  822. List of intermediate features.
  823. """
  824. return self.forward_intermediates(
  825. x, n,
  826. return_prefix_tokens=return_prefix_tokens,
  827. norm=norm,
  828. output_fmt='NCHW' if reshape else 'NLC',
  829. intermediates_only=True,
  830. attn_mask=attn_mask,
  831. )
  832. def forward_features(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  833. """Forward pass through feature layers (embeddings, transformer blocks, post-transformer norm)."""
  834. x = self.patch_embed(x)
  835. x = self._pos_embed(x)
  836. x = self.patch_drop(x)
  837. x = self.norm_pre(x)
  838. if attn_mask is not None:
  839. # If mask provided, we need to apply blocks one by one
  840. for blk in self.blocks:
  841. x = blk(x, attn_mask=attn_mask)
  842. elif self.grad_checkpointing and not torch.jit.is_scripting():
  843. x = checkpoint_seq(self.blocks, x)
  844. else:
  845. x = self.blocks(x)
  846. x = self.norm(x)
  847. return x
  848. def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
  849. """Apply pooling to feature tokens.
  850. Args:
  851. x: Feature tensor.
  852. pool_type: Pooling type override.
  853. Returns:
  854. Pooled features.
  855. """
  856. if self.attn_pool is not None:
  857. if not self.pool_include_prefix:
  858. x = x[:, self.num_prefix_tokens:]
  859. x = self.attn_pool(x)
  860. return x
  861. pool_type = self.global_pool if pool_type is None else pool_type
  862. x = global_pool_nlc(
  863. x,
  864. pool_type=pool_type,
  865. num_prefix_tokens=self.num_prefix_tokens,
  866. reduce_include_prefix=self.pool_include_prefix,
  867. )
  868. return x
  869. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  870. """Forward pass through classifier head.
  871. Args:
  872. x: Feature tensor.
  873. pre_logits: Return features before final classifier.
  874. Returns:
  875. Output tensor.
  876. """
  877. x = self.pool(x)
  878. x = self.fc_norm(x)
  879. x = self.head_drop(x)
  880. return x if pre_logits else self.head(x)
  881. def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
  882. x = self.forward_features(x, attn_mask=attn_mask)
  883. x = self.forward_head(x)
  884. return x
  885. def init_weights_vit_timm(module: nn.Module, name: str = '') -> None:
  886. """ViT weight initialization, original timm impl (for reproducibility).
  887. Args:
  888. module: Module to initialize.
  889. name: Module name for context.
  890. """
  891. if isinstance(module, nn.Linear):
  892. trunc_normal_(module.weight, std=.02)
  893. if module.bias is not None:
  894. nn.init.zeros_(module.bias)
  895. elif hasattr(module, 'init_weights'):
  896. module.init_weights()
  897. def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.0) -> None:
  898. """ViT weight initialization, matching JAX (Flax) impl.
  899. Args:
  900. module: Module to initialize.
  901. name: Module name for context.
  902. head_bias: Bias value for head layer.
  903. """
  904. if isinstance(module, nn.Linear):
  905. if name.startswith('head'):
  906. nn.init.zeros_(module.weight)
  907. nn.init.constant_(module.bias, head_bias)
  908. else:
  909. nn.init.xavier_uniform_(module.weight)
  910. if module.bias is not None:
  911. nn.init.normal_(module.bias, std=1e-6) if 'mlp' in name else nn.init.zeros_(module.bias)
  912. elif isinstance(module, nn.Conv2d):
  913. lecun_normal_(module.weight)
  914. if module.bias is not None:
  915. nn.init.zeros_(module.bias)
  916. elif hasattr(module, 'init_weights'):
  917. module.init_weights()
  918. def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
  919. """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed.
  920. Args:
  921. module: Module to initialize.
  922. name: Module name for context.
  923. """
  924. if isinstance(module, nn.Linear):
  925. if 'qkv' in name:
  926. # treat the weights of Q, K, V separately
  927. val = math.sqrt(6. / float(module.weight.shape[0] // 3 + module.weight.shape[1]))
  928. nn.init.uniform_(module.weight, -val, val)
  929. else:
  930. nn.init.xavier_uniform_(module.weight)
  931. if module.bias is not None:
  932. nn.init.zeros_(module.bias)
  933. elif hasattr(module, 'init_weights'):
  934. module.init_weights()
  935. def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
  936. if 'jax' in mode:
  937. return partial(init_weights_vit_jax, head_bias=head_bias)
  938. elif 'moco' in mode:
  939. return init_weights_vit_moco
  940. else:
  941. return init_weights_vit_timm
  942. def resize_pos_embed(
  943. posemb: torch.Tensor,
  944. posemb_new: torch.Tensor,
  945. num_prefix_tokens: int = 1,
  946. gs_new: Tuple[int, int] = (),
  947. interpolation: str = 'bicubic',
  948. antialias: bool = False,
  949. ) -> torch.Tensor:
  950. """ Rescale the grid of position embeddings when loading from state_dict.
  951. *DEPRECATED* This function is being deprecated in favour of using resample_abs_pos_embed
  952. """
  953. ntok_new = posemb_new.shape[1] - num_prefix_tokens
  954. ntok_old = posemb.shape[1] - num_prefix_tokens
  955. gs_old = [int(math.sqrt(ntok_old))] * 2
  956. if not len(gs_new): # backwards compatibility
  957. gs_new = [int(math.sqrt(ntok_new))] * 2
  958. return resample_abs_pos_embed(
  959. posemb, gs_new, gs_old,
  960. num_prefix_tokens=num_prefix_tokens,
  961. interpolation=interpolation,
  962. antialias=antialias,
  963. verbose=True,
  964. )
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '', load_bfloat16: bool = False) -> None:
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation

    Args:
        model: Model whose parameters are overwritten in place (via ``copy_``).
        checkpoint_path: Path handed to ``np.load`` (or ``jnp.load`` when ``load_bfloat16``).
        prefix: Key prefix inside the checkpoint. When empty, it is auto-detected from
            a few known ``.../embedding/kernel`` keys.
        load_bfloat16: Load through jax and reinterpret arrays as bfloat16 before
            converting to float32.
    """
    import numpy as np
    if load_bfloat16:
        # only needed (and only imported) for the bfloat16 path
        import jax.numpy as jnp
        import ml_dtypes

    def _n2p(_w, t=True, idx=None):
        # Convert one checkpoint array to a torch tensor.
        #   idx: optional index applied first (used to pick one layer out of a stacked array)
        #   t:   when True, reorder axes of 4D / 3D / 2D arrays
        #        (presumably Flax kernel layout -> PyTorch layout; confirm against checkpoint)
        if idx is not None:
            _w = _w[idx]

        if load_bfloat16:
            _w = _w.view(ml_dtypes.bfloat16).astype(jnp.float32)
            _w = np.array(_w)

        # 4D arrays with three leading singleton dims are collapsed to 1D
        if _w.ndim == 4 and _w.shape[0] == _w.shape[1] == _w.shape[2] == 1:
            _w = _w.flatten()
        if t:
            if _w.ndim == 4:
                _w = _w.transpose([3, 2, 0, 1])
            elif _w.ndim == 3:
                _w = _w.transpose([2, 0, 1])
            elif _w.ndim == 2:
                _w = _w.transpose([1, 0])

        _w = torch.from_numpy(_w)
        return _w

    if load_bfloat16:
        w = jnp.load(checkpoint_path)
    else:
        w = np.load(checkpoint_path)

    interpolation = 'bilinear'
    antialias = False
    big_vision = False
    if not prefix:
        # auto-detect the key prefix; two of the layouts also switch on big_vision key naming
        if 'opt/target/embedding/kernel' in w:
            prefix = 'opt/target/'
        elif 'params/embedding/kernel' in w:
            prefix = 'params/'
            big_vision = True
        elif 'params/img/embedding/kernel' in w:
            prefix = 'params/img/'
            big_vision = True

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    # checkpoint block / unit numbering is 1-based
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        embed_conv_w = adapt_input_conv(
            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
    # resample the patch-embedding kernel when its spatial size differs from the model's
    if embed_conv_w.shape[-2:] != model.patch_embed.proj.weight.shape[-2:]:
        embed_conv_w = resample_patch_embed(
            embed_conv_w,
            model.patch_embed.proj.weight.shape[-2:],
            interpolation=interpolation,
            antialias=antialias,
            verbose=True,
        )

    model.patch_embed.proj.weight.copy_(embed_conv_w)
    # NOTE(review): assumes patch_embed.proj has a bias; a bias-free projection would fail here
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    if model.cls_token is not None:
        model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    if big_vision:
        pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False)
    else:
        pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    # NOTE(review): assumes model.pos_embed is not None
    if pos_embed_w.shape != model.pos_embed.shape:
        num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
        pos_embed_w = resample_abs_pos_embed(  # resize pos embedding when different size from pretrained weights
            pos_embed_w,
            new_size=model.patch_embed.grid_size,
            num_prefix_tokens=num_prefix_tokens,
            interpolation=interpolation,
            antialias=antialias,
            verbose=True,
        )
    model.pos_embed.copy_(pos_embed_w)
    # NOTE(review): assumes model.norm has weight / bias (i.e. is not nn.Identity)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    # the head is only loaded when the checkpoint has one with a matching number of outputs
    if (isinstance(model.head, nn.Linear) and
            f'{prefix}head/bias' in w and
            model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]):
        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
    # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
    #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
    #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    if model.attn_pool is not None:
        block_prefix = f'{prefix}MAPHead_0/'
        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/'
        model.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False))
        # key and value kernels / biases are concatenated into the fused kv projection
        model.attn_pool.kv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')]))
        model.attn_pool.kv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')]))
        model.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T)
        model.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1))
        model.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        model.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        model.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        model.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        for r in range(2):
            getattr(model.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel']))
            getattr(model.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias']))

    # sub-module numbering inside an encoder block differs between big_vision and older checkpoints
    mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
    for i, block in enumerate(model.blocks.children()):
        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
            # single 'encoderblock' entry: arrays are indexed by block number via idx
            block_prefix = f'{prefix}Transformer/encoderblock/'
            idx = i
        else:
            # one 'encoderblock_{i}' entry per block
            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
            idx = None
        mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
        # query, key and value kernels / biases are concatenated into the fused qkv projection
        block.attn.qkv.weight.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
        block.attn.qkv.bias.copy_(torch.cat([
            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
  1109. def _convert_openai_clip(
  1110. state_dict: Dict[str, torch.Tensor],
  1111. model: VisionTransformer,
  1112. prefix: str = 'visual.',
  1113. ) -> Dict[str, torch.Tensor]:
  1114. out_dict = {}
  1115. swaps = [
  1116. ('conv1', 'patch_embed.proj'),
  1117. ('positional_embedding', 'pos_embed'),
  1118. ('transformer.resblocks.', 'blocks.'),
  1119. ('ln_pre', 'norm_pre'),
  1120. ('ln_post', 'norm'),
  1121. ('ln_', 'norm'),
  1122. ('in_proj_', 'qkv.'),
  1123. ('out_proj', 'proj'),
  1124. ('mlp.c_fc', 'mlp.fc1'),
  1125. ('mlp.c_proj', 'mlp.fc2'),
  1126. ]
  1127. for k, v in state_dict.items():
  1128. if not k.startswith(prefix):
  1129. continue
  1130. k = k.replace(prefix, '')
  1131. for sp in swaps:
  1132. k = k.replace(sp[0], sp[1])
  1133. if k == 'proj':
  1134. k = 'head.weight'
  1135. v = v.transpose(0, 1)
  1136. out_dict['head.bias'] = torch.zeros(v.shape[0])
  1137. elif k == 'class_embedding':
  1138. k = 'cls_token'
  1139. v = v.unsqueeze(0).unsqueeze(1)
  1140. elif k == 'pos_embed':
  1141. v = v.unsqueeze(0)
  1142. out_dict[k] = v
  1143. return out_dict
  1144. def _convert_dinov2(
  1145. state_dict: Dict[str, torch.Tensor],
  1146. model: VisionTransformer,
  1147. ) -> Dict[str, torch.Tensor]:
  1148. import re
  1149. out_dict = {}
  1150. state_dict.pop("mask_token", None)
  1151. if 'register_tokens' in state_dict:
  1152. # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed)
  1153. out_dict['reg_token'] = state_dict.pop('register_tokens')
  1154. out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
  1155. out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]
  1156. for k, v in state_dict.items():
  1157. if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
  1158. out_dict[k.replace("w12", "fc1")] = v
  1159. continue
  1160. elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
  1161. out_dict[k.replace("w3", "fc2")] = v
  1162. continue
  1163. out_dict[k] = v
  1164. return out_dict
  1165. def _convert_aimv2(
  1166. state_dict: Dict[str, torch.Tensor],
  1167. model: VisionTransformer,
  1168. ) -> Dict[str, torch.Tensor]:
  1169. out_dict = {}
  1170. for k, v in state_dict.items():
  1171. k = k.replace('norm_1', 'norm1')
  1172. k = k.replace('norm_2', 'norm2')
  1173. k = k.replace('preprocessor.patchifier.', 'patch_embed.')
  1174. k = k.replace('preprocessor.pos_embed', 'pos_embed')
  1175. k = k.replace('trunk.', '')
  1176. k = k.replace('post_trunk_norm.', 'norm.')
  1177. k = k.replace('mlp.fc1', 'mlp.fc1_g')
  1178. k = k.replace('mlp.fc3', 'mlp.fc1_x')
  1179. out_dict[k] = v
  1180. return out_dict
  1181. def _convert_beit3(state_dict: dict, model):
  1182. """
  1183. Turn a BEiT-3 checkpoint into a standard VisionTransformer state-dict.
  1184. """
  1185. import re
  1186. state_dict = state_dict.get("model", state_dict) # unwrap if needed
  1187. # Prune unused
  1188. for k in ("beit3.text_embed.weight", "beit3.vision_embed.mask_token"):
  1189. state_dict.pop(k, None)
  1190. # Key renaming rules
  1191. rules = [
  1192. (r"beit3\.", ""),
  1193. (r"vision_embed\.cls_token", "cls_token"),
  1194. (r"vision_embed\.", "patch_embed."),
  1195. (r"embed_positions\.", "pos_embed."),
  1196. (r"encoder\.", ""),
  1197. (r"layers\.", "blocks."),
  1198. (r"ffn_layernorm\.", "norm."), (r"ffn\.", "mlp."),
  1199. (r"self_attn_layer_norm\.", "norm1."), (r"self_attn\.", "attn."),
  1200. (r"final_layer_norm\.", "norm2."),
  1201. (r"inner_attn_ln", "norm"),
  1202. (r"out_proj", "proj"),
  1203. (r"\.A\.", "."),
  1204. ]
  1205. # First pass, rename keys
  1206. tmp = {}
  1207. for k, v in state_dict.items():
  1208. if ".B." in k:
  1209. continue # use branch-A only
  1210. for old, new in rules:
  1211. k = re.sub(old, new, k)
  1212. if k == "pos_embed.weight":
  1213. # strip first two positions, [1, N+1, D]
  1214. tmp["pos_embed"] = v[2:].unsqueeze(0)
  1215. else:
  1216. tmp[k] = v
  1217. # Second pass, fuse q, k, v
  1218. out, buf = {}, {}
  1219. pat = re.compile(r"blocks\.(\d+)\.attn\.(q|k|v)_proj\.(weight|bias)$")
  1220. for k, v in tmp.items():
  1221. m = pat.fullmatch(k)
  1222. if not m: # anything not q/k/v -> copy through
  1223. out[k] = v
  1224. continue
  1225. blk, which, kind = m.groups() # block idx, 'q'/'k'/'v', 'weight'/'bias'
  1226. stash = buf.setdefault((blk, kind), {}) # Gather by block & param type
  1227. stash[which] = v
  1228. if len(stash) == 3: # Have q, k, v -> concatenate
  1229. out[f"blocks.{blk}.attn.qkv.{kind}"] = torch.cat(
  1230. [stash['q'], stash['k'], stash['v']], dim=0
  1231. )
  1232. return out
  1233. def checkpoint_filter_fn(
  1234. state_dict: Dict[str, torch.Tensor],
  1235. model: VisionTransformer,
  1236. adapt_layer_scale: bool = False,
  1237. interpolation: str = 'bicubic',
  1238. antialias: bool = True,
  1239. ) -> Dict[str, torch.Tensor]:
  1240. """ convert patch embedding weight from manual patchify + linear proj to conv"""
  1241. import re
  1242. out_dict = {}
  1243. state_dict = state_dict.get('model', state_dict)
  1244. state_dict = state_dict.get('state_dict', state_dict)
  1245. prefix = ''
  1246. if 'visual.class_embedding' in state_dict:
  1247. state_dict = _convert_openai_clip(state_dict, model)
  1248. elif 'module.visual.class_embedding' in state_dict:
  1249. state_dict = _convert_openai_clip(state_dict, model, prefix='module.visual.')
  1250. elif "mask_token" in state_dict:
  1251. state_dict = _convert_dinov2(state_dict, model)
  1252. elif any('beit3.' in k for k in state_dict.keys()):
  1253. # BEiT3 model - multimodal checkpoint with beit3.* prefix
  1254. state_dict = _convert_beit3(state_dict, model)
  1255. elif "encoder" in state_dict:
  1256. # IJEPA, vit in an 'encoder' submodule
  1257. state_dict = state_dict['encoder']
  1258. prefix = 'module.'
  1259. elif 'visual.trunk.pos_embed' in state_dict or 'visual.trunk.blocks.0.norm1.weight' in state_dict:
  1260. # OpenCLIP model with timm vision encoder
  1261. prefix = 'visual.trunk.'
  1262. if 'visual.head.proj.weight' in state_dict and isinstance(model.head, nn.Linear):
  1263. # remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
  1264. out_dict['head.weight'] = state_dict['visual.head.proj.weight']
  1265. out_dict['head.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
  1266. elif 'module.visual.trunk.pos_embed' in state_dict:
  1267. prefix = 'module.visual.trunk.'
  1268. elif 'preprocessor.patchifier.proj.weight' in state_dict:
  1269. state_dict = _convert_aimv2(state_dict, model)
  1270. if prefix:
  1271. # filter on & remove prefix string from keys
  1272. state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
  1273. for k, v in state_dict.items():
  1274. if 'patch_embed.proj.weight' in k:
  1275. O, I, H, W = model.patch_embed.proj.weight.shape
  1276. if len(v.shape) < 4:
  1277. # For old models that I trained prior to conv based patchification
  1278. O, I, H, W = model.patch_embed.proj.weight.shape
  1279. v = v.reshape(O, -1, H, W)
  1280. if v.shape[-1] != W or v.shape[-2] != H:
  1281. v = resample_patch_embed(
  1282. v,
  1283. (H, W),
  1284. interpolation=interpolation,
  1285. antialias=antialias,
  1286. verbose=True,
  1287. )
  1288. elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
  1289. # To resize pos embedding when using model at different size from pretrained weights
  1290. num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
  1291. v = resample_abs_pos_embed(
  1292. v,
  1293. new_size=model.patch_embed.grid_size,
  1294. num_prefix_tokens=num_prefix_tokens,
  1295. interpolation=interpolation,
  1296. antialias=antialias,
  1297. verbose=True,
  1298. )
  1299. elif adapt_layer_scale and 'gamma_' in k:
  1300. # remap layer-scale gamma into sub-module (deit3 models)
  1301. k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
  1302. elif 'pre_logits' in k:
  1303. # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
  1304. continue
  1305. out_dict[k] = v
  1306. return out_dict
  1307. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  1308. return {
  1309. 'url': url,
  1310. 'num_classes': 1000,
  1311. 'input_size': (3, 224, 224),
  1312. 'pool_size': None,
  1313. 'crop_pct': 0.9,
  1314. 'interpolation': 'bicubic',
  1315. 'fixed_input_size': True,
  1316. 'mean': IMAGENET_INCEPTION_MEAN,
  1317. 'std': IMAGENET_INCEPTION_STD,
  1318. 'first_conv': 'patch_embed.proj',
  1319. 'classifier': 'head',
  1320. 'license': 'apache-2.0',
  1321. **kwargs,
  1322. }
  1323. default_cfgs = {
  1324. # re-finetuned augreg 21k FT on in1k weights
  1325. 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
  1326. hf_hub_id='timm/'),
  1327. 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(),
  1328. 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
  1329. hf_hub_id='timm/'),
  1330. # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
  1331. 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
  1332. url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
  1333. hf_hub_id='timm/',
  1334. custom_load=True),
  1335. 'vit_tiny_patch16_384.augreg_in21k_ft_in1k': _cfg(
  1336. url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
  1337. hf_hub_id='timm/',
  1338. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1339. 'vit_small_patch32_224.augreg_in21k_ft_in1k': _cfg(
  1340. url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
  1341. hf_hub_id='timm/',
  1342. custom_load=True),
  1343. 'vit_small_patch32_384.augreg_in21k_ft_in1k': _cfg(
  1344. url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
  1345. hf_hub_id='timm/',
  1346. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1347. 'vit_small_patch16_224.augreg_in21k_ft_in1k': _cfg(
  1348. url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
  1349. hf_hub_id='timm/',
  1350. custom_load=True),
  1351. 'vit_small_patch16_384.augreg_in21k_ft_in1k': _cfg(
  1352. url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
  1353. hf_hub_id='timm/',
  1354. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1355. 'vit_base_patch32_224.augreg_in21k_ft_in1k': _cfg(
  1356. url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
  1357. hf_hub_id='timm/',
  1358. custom_load=True),
  1359. 'vit_base_patch32_384.augreg_in21k_ft_in1k': _cfg(
  1360. url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
  1361. hf_hub_id='timm/',
  1362. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1363. 'vit_base_patch16_224.augreg_in21k_ft_in1k': _cfg(
  1364. url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1365. hf_hub_id='timm/',
  1366. custom_load=True),
  1367. 'vit_base_patch16_384.augreg_in21k_ft_in1k': _cfg(
  1368. url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1369. hf_hub_id='timm/',
  1370. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1371. 'vit_base_patch8_224.augreg_in21k_ft_in1k': _cfg(
  1372. url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1373. hf_hub_id='timm/',
  1374. custom_load=True),
  1375. 'vit_large_patch16_224.augreg_in21k_ft_in1k': _cfg(
  1376. url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1377. hf_hub_id='timm/',
  1378. custom_load=True),
  1379. 'vit_large_patch16_384.augreg_in21k_ft_in1k': _cfg(
  1380. url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1381. hf_hub_id='timm/',
  1382. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1383. # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
  1384. 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
  1385. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
  1386. hf_hub_id='timm/'),
  1387. 'vit_base_patch16_384.orig_in21k_ft_in1k': _cfg(
  1388. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
  1389. hf_hub_id='timm/',
  1390. input_size=(3, 384, 384), crop_pct=1.0),
  1391. 'vit_large_patch32_384.orig_in21k_ft_in1k': _cfg(
  1392. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
  1393. hf_hub_id='timm/',
  1394. input_size=(3, 384, 384), crop_pct=1.0),
  1395. # How to train your ViT (augreg) weights trained on in1k only
  1396. 'vit_small_patch16_224.augreg_in1k': _cfg(
  1397. url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1398. hf_hub_id='timm/',
  1399. custom_load=True),
  1400. 'vit_small_patch16_384.augreg_in1k': _cfg(
  1401. url='https://storage.googleapis.com/vit_models/augreg/S_16-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1402. hf_hub_id='timm/',
  1403. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1404. 'vit_base_patch32_224.augreg_in1k': _cfg(
  1405. url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1406. hf_hub_id='timm/',
  1407. custom_load=True),
  1408. 'vit_base_patch32_384.augreg_in1k': _cfg(
  1409. url='https://storage.googleapis.com/vit_models/augreg/B_32-i1k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1410. hf_hub_id='timm/',
  1411. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1412. 'vit_base_patch16_224.augreg_in1k': _cfg(
  1413. url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
  1414. hf_hub_id='timm/',
  1415. custom_load=True),
  1416. 'vit_base_patch16_384.augreg_in1k': _cfg(
  1417. url='https://storage.googleapis.com/vit_models/augreg/B_16-i1k-300ep-lr_0.001-aug_strong2-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
  1418. hf_hub_id='timm/',
  1419. custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
  1420. 'vit_large_patch14_224.untrained': _cfg(url=''),
  1421. 'vit_huge_patch14_224.untrained': _cfg(url=''),
  1422. 'vit_giant_patch14_224.untrained': _cfg(url=''),
  1423. 'vit_gigantic_patch14_224.untrained': _cfg(url=''),
  1424. # patch models, imagenet21k (weights from official Google JAX impl), classifier not valid
  1425. 'vit_base_patch32_224.orig_in21k': _cfg(
  1426. #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
  1427. hf_hub_id='timm/',
  1428. num_classes=0),
  1429. 'vit_base_patch16_224.orig_in21k': _cfg(
  1430. #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
  1431. hf_hub_id='timm/',
  1432. num_classes=0),
  1433. 'vit_large_patch32_224.orig_in21k': _cfg(
  1434. #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
  1435. hf_hub_id='timm/',
  1436. num_classes=0),
  1437. 'vit_large_patch16_224.orig_in21k': _cfg(
  1438. #url='https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
  1439. hf_hub_id='timm/',
  1440. num_classes=0),
  1441. 'vit_huge_patch14_224.orig_in21k': _cfg(
  1442. hf_hub_id='timm/',
  1443. num_classes=0),
  1444. # How to train your ViT (augreg) weights, pretrained on in21k
  1445. 'vit_tiny_patch16_224.augreg_in21k': _cfg(
  1446. url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
  1447. hf_hub_id='timm/',
  1448. custom_load=True, num_classes=21843),
  1449. 'vit_small_patch32_224.augreg_in21k': _cfg(
  1450. url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
  1451. hf_hub_id='timm/',
  1452. custom_load=True, num_classes=21843),
  1453. 'vit_small_patch16_224.augreg_in21k': _cfg(
  1454. url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
  1455. hf_hub_id='timm/',
  1456. custom_load=True, num_classes=21843),
  1457. 'vit_base_patch32_224.augreg_in21k': _cfg(
  1458. url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
  1459. hf_hub_id='timm/',
  1460. custom_load=True, num_classes=21843),
  1461. 'vit_base_patch16_224.augreg_in21k': _cfg(
  1462. url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
  1463. hf_hub_id='timm/',
  1464. custom_load=True, num_classes=21843),
  1465. 'vit_base_patch8_224.augreg_in21k': _cfg(
  1466. url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
  1467. hf_hub_id='timm/',
  1468. custom_load=True, num_classes=21843),
  1469. 'vit_large_patch16_224.augreg_in21k': _cfg(
  1470. url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
  1471. hf_hub_id='timm/',
  1472. custom_load=True, num_classes=21843),
  1473. # SAM trained models (https://arxiv.org/abs/2106.01548)
  1474. 'vit_base_patch32_224.sam_in1k': _cfg(
  1475. url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz', custom_load=True,
  1476. hf_hub_id='timm/'),
  1477. 'vit_base_patch16_224.sam_in1k': _cfg(
  1478. url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz', custom_load=True,
  1479. hf_hub_id='timm/'),
  1480. # DINO pretrained - https://arxiv.org/abs/2104.14294 (no classifier head, for fine-tune only)
  1481. 'vit_small_patch16_224.dino': _cfg(
  1482. url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth',
  1483. hf_hub_id='timm/',
  1484. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1485. 'vit_small_patch8_224.dino': _cfg(
  1486. url='https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth',
  1487. hf_hub_id='timm/',
  1488. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1489. 'vit_base_patch16_224.dino': _cfg(
  1490. url='https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth',
  1491. hf_hub_id='timm/',
  1492. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1493. 'vit_base_patch8_224.dino': _cfg(
  1494. url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
  1495. hf_hub_id='timm/',
  1496. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1497. # DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
  1498. 'vit_small_patch14_dinov2.lvd142m': _cfg(
  1499. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
  1500. hf_hub_id='timm/',
  1501. license='apache-2.0',
  1502. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1503. input_size=(3, 518, 518), crop_pct=1.0),
  1504. 'vit_base_patch14_dinov2.lvd142m': _cfg(
  1505. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth',
  1506. hf_hub_id='timm/',
  1507. license='apache-2.0',
  1508. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1509. input_size=(3, 518, 518), crop_pct=1.0),
  1510. 'vit_large_patch14_dinov2.lvd142m': _cfg(
  1511. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth',
  1512. hf_hub_id='timm/',
  1513. license='apache-2.0',
  1514. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1515. input_size=(3, 518, 518), crop_pct=1.0),
  1516. 'vit_giant_patch14_dinov2.lvd142m': _cfg(
  1517. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth',
  1518. hf_hub_id='timm/',
  1519. license='apache-2.0',
  1520. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1521. input_size=(3, 518, 518), crop_pct=1.0),
  1522. # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only)
  1523. 'vit_small_patch14_reg4_dinov2.lvd142m': _cfg(
  1524. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth',
  1525. hf_hub_id='timm/',
  1526. license='apache-2.0',
  1527. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1528. input_size=(3, 518, 518), crop_pct=1.0),
  1529. 'vit_base_patch14_reg4_dinov2.lvd142m': _cfg(
  1530. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth',
  1531. hf_hub_id='timm/',
  1532. license='apache-2.0',
  1533. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1534. input_size=(3, 518, 518), crop_pct=1.0),
  1535. 'vit_large_patch14_reg4_dinov2.lvd142m': _cfg(
  1536. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth',
  1537. hf_hub_id='timm/',
  1538. license='apache-2.0',
  1539. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1540. input_size=(3, 518, 518), crop_pct=1.0),
  1541. 'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg(
  1542. url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth',
  1543. hf_hub_id='timm/',
  1544. license='apache-2.0',
  1545. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
  1546. input_size=(3, 518, 518), crop_pct=1.0),
  1547. # ViT ImageNet-21K-P pretraining by MIIL
  1548. 'vit_base_patch16_224_miil.in21k': _cfg(
  1549. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
  1550. hf_hub_id='timm/',
  1551. mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221),
  1552. 'vit_base_patch16_224_miil.in21k_ft_in1k': _cfg(
  1553. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_1k_miil_84_4-2deb18e3.pth',
  1554. hf_hub_id='timm/',
  1555. mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
  1556. # Custom timm variants
  1557. 'vit_base_patch16_rpn_224.sw_in1k': _cfg(
  1558. url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth',
  1559. hf_hub_id='timm/'),
  1560. 'vit_medium_patch16_gap_240.sw_in12k': _cfg(
  1561. hf_hub_id='timm/',
  1562. input_size=(3, 240, 240), crop_pct=0.95, num_classes=11821),
  1563. 'vit_medium_patch16_gap_256.sw_in12k_ft_in1k': _cfg(
  1564. hf_hub_id='timm/',
  1565. input_size=(3, 256, 256), crop_pct=0.95),
  1566. 'vit_medium_patch16_gap_384.sw_in12k_ft_in1k': _cfg(
  1567. hf_hub_id='timm/',
  1568. input_size=(3, 384, 384), crop_pct=0.95, crop_mode='squash'),
  1569. 'vit_betwixt_patch16_gap_256.untrained': _cfg(
  1570. input_size=(3, 256, 256), crop_pct=0.95),
  1571. 'vit_base_patch16_gap_224.untrained': _cfg(),
  1572. # CLIP pretrained image tower and related fine-tuned weights
  1573. 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg(
  1574. hf_hub_id='timm/',
  1575. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1576. 'vit_base_patch32_clip_384.laion2b_ft_in12k_in1k': _cfg(
  1577. hf_hub_id='timm/',
  1578. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 384, 384)),
  1579. 'vit_base_patch32_clip_448.laion2b_ft_in12k_in1k': _cfg(
  1580. hf_hub_id='timm/',
  1581. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 448, 448)),
  1582. 'vit_base_patch16_clip_224.laion2b_ft_in12k_in1k': _cfg(
  1583. hf_hub_id='timm/',
  1584. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
  1585. 'vit_base_patch16_clip_384.laion2b_ft_in12k_in1k': _cfg(
  1586. hf_hub_id='timm/',
  1587. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1588. crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
  1589. 'vit_large_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
  1590. hf_hub_id='timm/',
  1591. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
  1592. 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
  1593. hf_hub_id='timm/',
  1594. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1595. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1596. 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg(
  1597. hf_hub_id='timm/',
  1598. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1599. 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg(
  1600. hf_hub_id='timm/',
  1601. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1602. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1603. 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
  1604. # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', # FIXME weight exists, need to push
  1605. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1606. 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
  1607. hf_hub_id='timm/',
  1608. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1609. crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
  1610. 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
  1611. hf_hub_id='timm/',
  1612. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
  1613. 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
  1614. hf_hub_id='timm/',
  1615. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1616. crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
  1617. 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
  1618. hf_hub_id='timm/',
  1619. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1620. 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
  1621. hf_hub_id='timm/',
  1622. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1623. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1624. 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
  1625. hf_hub_id='timm/',
  1626. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1627. 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
  1628. hf_hub_id='timm/',
  1629. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1630. 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
  1631. hf_hub_id='timm/',
  1632. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1633. crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
  1634. 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
  1635. hf_hub_id='timm/',
  1636. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
  1637. 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
  1638. hf_hub_id='timm/',
  1639. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
  1640. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1641. 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
  1642. hf_hub_id='timm/',
  1643. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1644. 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
  1645. hf_hub_id='',
  1646. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1647. crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
  1648. 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg(
  1649. hf_hub_id='timm/',
  1650. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1651. 'vit_base_patch16_clip_224.openai_ft_in1k': _cfg(
  1652. hf_hub_id='timm/',
  1653. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
  1654. 'vit_base_patch16_clip_384.openai_ft_in1k': _cfg(
  1655. hf_hub_id='timm/',
  1656. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1657. crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
  1658. 'vit_large_patch14_clip_224.openai_ft_in1k': _cfg(
  1659. hf_hub_id='timm/',
  1660. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
  1661. 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
  1662. hf_hub_id='timm/',
  1663. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
  1664. 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
  1665. hf_hub_id='timm/',
  1666. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
  1667. 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
  1668. hf_hub_id='timm/',
  1669. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
  1670. 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg(
  1671. hf_hub_id='timm/',
  1672. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
  1673. 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg(
  1674. hf_hub_id='timm/',
  1675. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
  1676. 'vit_base_patch32_clip_224.laion2b': _cfg(
  1677. hf_hub_id='timm/',
  1678. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  1679. 'vit_base_patch16_clip_224.laion2b': _cfg(
  1680. hf_hub_id='timm/',
  1681. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1682. 'vit_large_patch14_clip_224.laion2b': _cfg(
  1683. hf_hub_id='timm/',
  1684. mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
  1685. 'vit_huge_patch14_clip_224.laion2b': _cfg(
  1686. hf_hub_id='timm/',
  1687. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  1688. 'vit_giant_patch14_clip_224.laion2b': _cfg(
  1689. hf_hub_id='timm/',
  1690. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  1691. 'vit_gigantic_patch14_clip_224.laion2b': _cfg(
  1692. hf_hub_id='timm/',
  1693. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
  1694. 'vit_base_patch32_clip_224.laion400m_e32': _cfg(
  1695. hf_hub_id='timm/',
  1696. license='mit',
  1697. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1698. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  1699. 'vit_base_patch16_clip_224.laion400m_e32': _cfg(
  1700. hf_hub_id='timm/',
  1701. license='mit',
  1702. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1703. 'vit_base_patch16_plus_clip_240.laion400m_e32': _cfg(
  1704. hf_hub_id='timm/',
  1705. license='mit',
  1706. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1707. input_size=(3, 240, 240), crop_pct=1.0, num_classes=640),
  1708. 'vit_large_patch14_clip_224.laion400m_e32': _cfg(
  1709. hf_hub_id='timm/',
  1710. license='mit',
  1711. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  1712. 'vit_base_patch32_clip_224.datacompxl': _cfg(
  1713. hf_hub_id='timm/',
  1714. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1715. 'vit_base_patch32_clip_256.datacompxl': _cfg(
  1716. hf_hub_id='timm/',
  1717. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1718. crop_pct=1.0, input_size=(3, 256, 256), num_classes=512),
  1719. 'vit_base_patch16_clip_224.datacompxl': _cfg(
  1720. hf_hub_id='timm/',
  1721. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1722. 'vit_large_patch14_clip_224.datacompxl': _cfg(
  1723. hf_hub_id='timm/',
  1724. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  1725. 'vit_base_patch16_clip_224.dfn2b': _cfg(
  1726. hf_hub_id='timm/',
  1727. license='apple-ascl',
  1728. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1729. 'vit_large_patch14_clip_224.dfn2b_s39b': _cfg(
  1730. hf_hub_id='timm/',
  1731. license='apple-ascl',
  1732. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  1733. 'vit_large_patch14_clip_224.dfn2b': _cfg(
  1734. hf_hub_id='timm/',
  1735. license='apple-ascl',
  1736. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1737. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  1738. 'vit_huge_patch14_clip_224.dfn5b': _cfg(
  1739. hf_hub_id='timm/',
  1740. license='apple-ascl',
  1741. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1742. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  1743. 'vit_huge_patch14_clip_378.dfn5b': _cfg(
  1744. hf_hub_id='timm/',
  1745. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1746. license='apple-ascl',
  1747. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1748. crop_pct=1.0, input_size=(3, 378, 378), num_classes=1024),
  1749. # 'vit_large_patch14_clip_224.metaclip2_worldwide': _cfg(
  1750. # hf_hub_id='timm/',
  1751. # license='cc-by-nc-4.0',
  1752. # notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1753. # mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  1754. 'vit_huge_patch14_clip_224.metaclip2_worldwide': _cfg(
  1755. hf_hub_id='timm/',
  1756. license='cc-by-nc-4.0',
  1757. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1758. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  1759. 'vit_huge_patch14_clip_378.metaclip2_worldwide': _cfg(
  1760. hf_hub_id='timm/',
  1761. license='cc-by-nc-4.0',
  1762. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1763. input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash', num_classes=1024),
  1764. 'vit_gigantic_patch14_clip_224.metaclip2_worldwide': _cfg(
  1765. hf_hub_id='timm/',
  1766. license='cc-by-nc-4.0',
  1767. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
  1768. 'vit_gigantic_patch14_clip_378.metaclip2_worldwide': _cfg(
  1769. hf_hub_id='timm/',
  1770. license='cc-by-nc-4.0',
  1771. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1772. input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash', num_classes=1280),
  1773. 'vit_base_patch32_clip_224.metaclip_2pt5b': _cfg(
  1774. hf_hub_id='timm/',
  1775. license='cc-by-nc-4.0',
  1776. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1777. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1778. 'vit_base_patch16_clip_224.metaclip_2pt5b': _cfg(
  1779. hf_hub_id='timm/',
  1780. license='cc-by-nc-4.0',
  1781. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1782. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1783. 'vit_large_patch14_clip_224.metaclip_2pt5b': _cfg(
  1784. hf_hub_id='timm/',
  1785. license='cc-by-nc-4.0',
  1786. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1787. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  1788. 'vit_huge_patch14_clip_224.metaclip_2pt5b': _cfg(
  1789. hf_hub_id='timm/',
  1790. license='cc-by-nc-4.0',
  1791. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1792. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  1793. 'vit_huge_patch14_clip_224.metaclip_altogether': _cfg(
  1794. hf_hub_id='timm/',
  1795. license='cc-by-nc-4.0',
  1796. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
  1797. 'vit_gigantic_patch14_clip_224.metaclip_2pt5b': _cfg(
  1798. hf_hub_id='timm/',
  1799. license='cc-by-nc-4.0',
  1800. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1801. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
  1802. 'vit_base_patch32_clip_224.metaclip_400m': _cfg(
  1803. hf_hub_id='timm/',
  1804. license='cc-by-nc-4.0',
  1805. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1806. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1807. 'vit_base_patch16_clip_224.metaclip_400m': _cfg(
  1808. hf_hub_id='timm/',
  1809. license='cc-by-nc-4.0',
  1810. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1811. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
  1812. 'vit_large_patch14_clip_224.metaclip_400m': _cfg(
  1813. hf_hub_id='timm/',
  1814. license='cc-by-nc-4.0',
  1815. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1816. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  1817. 'vit_base_patch32_clip_224.openai': _cfg(
  1818. hf_hub_id='timm/',
  1819. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1820. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  1821. 'vit_base_patch16_clip_224.openai': _cfg(
  1822. hf_hub_id='timm/',
  1823. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1824. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  1825. 'vit_large_patch14_clip_224.openai': _cfg(
  1826. hf_hub_id='timm/',
  1827. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1828. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
  1829. 'vit_large_patch14_clip_336.openai': _cfg(
  1830. hf_hub_id='timm/',
  1831. notes=('natively QuickGELU, use quickgelu model variant for original results',),
  1832. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1833. crop_pct=1.0, input_size=(3, 336, 336), num_classes=768),
  1834. 'vit_large_patch14_clip_224.apple_mclip2_dfndr2b': _cfg(
  1835. hf_hub_id='timm/',
  1836. num_classes=768,
  1837. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
  1838. license='apple-amlr'
  1839. ),
  1840. # experimental (may be removed)
  1841. 'vit_base_patch32_plus_256.untrained': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
  1842. 'vit_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
  1843. 'vit_small_patch16_36x1_224.untrained': _cfg(url=''),
  1844. 'vit_small_patch16_18x2_224.untrained': _cfg(url=''),
  1845. 'vit_base_patch16_18x2_224.untrained': _cfg(url=''),
  1846. # EVA fine-tuned weights from MAE style MIM - EVA-CLIP target pretrain
  1847. # https://github.com/baaivision/EVA/blob/7ecf2c0a370d97967e86d047d7af9188f78d2df3/eva/README.md#eva-l-learning-better-mim-representations-from-eva-clip
  1848. 'eva_large_patch14_196.in22k_ft_in22k_in1k': _cfg(
  1849. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_21k_to_1k_ft_88p6.pt',
  1850. hf_hub_id='timm/', license='mit',
  1851. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1852. input_size=(3, 196, 196), crop_pct=1.0),
  1853. 'eva_large_patch14_336.in22k_ft_in22k_in1k': _cfg(
  1854. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_21k_to_1k_ft_89p2.pt',
  1855. hf_hub_id='timm/', license='mit',
  1856. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1857. input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
  1858. 'eva_large_patch14_196.in22k_ft_in1k': _cfg(
  1859. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_196px_1k_ft_88p0.pt',
  1860. hf_hub_id='timm/', license='mit',
  1861. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1862. input_size=(3, 196, 196), crop_pct=1.0),
  1863. 'eva_large_patch14_336.in22k_ft_in1k': _cfg(
  1864. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_l_psz14_336px_1k_ft_88p65.pt',
  1865. hf_hub_id='timm/', license='mit',
  1866. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
  1867. input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
  1868. 'flexivit_small.1200ep_in1k': _cfg(
  1869. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k.npz', custom_load=True,
  1870. hf_hub_id='timm/',
  1871. input_size=(3, 240, 240), crop_pct=0.95),
  1872. 'flexivit_small.600ep_in1k': _cfg(
  1873. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_600ep.npz', custom_load=True,
  1874. hf_hub_id='timm/',
  1875. input_size=(3, 240, 240), crop_pct=0.95),
  1876. 'flexivit_small.300ep_in1k': _cfg(
  1877. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_s_i1k_300ep.npz', custom_load=True,
  1878. hf_hub_id='timm/',
  1879. input_size=(3, 240, 240), crop_pct=0.95),
  1880. 'flexivit_base.1200ep_in1k': _cfg(
  1881. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k.npz', custom_load=True,
  1882. hf_hub_id='timm/',
  1883. input_size=(3, 240, 240), crop_pct=0.95),
  1884. 'flexivit_base.600ep_in1k': _cfg(
  1885. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_600ep.npz', custom_load=True,
  1886. hf_hub_id='timm/',
  1887. input_size=(3, 240, 240), crop_pct=0.95),
  1888. 'flexivit_base.300ep_in1k': _cfg(
  1889. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i1k_300ep.npz', custom_load=True,
  1890. hf_hub_id='timm/',
  1891. input_size=(3, 240, 240), crop_pct=0.95),
  1892. 'flexivit_base.1000ep_in21k': _cfg(
  1893. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_1000ep.npz', custom_load=True,
  1894. hf_hub_id='timm/',
  1895. input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
  1896. 'flexivit_base.300ep_in21k': _cfg(
  1897. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_b_i21k_300ep.npz', custom_load=True,
  1898. hf_hub_id='timm/',
  1899. input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
  1900. 'flexivit_large.1200ep_in1k': _cfg(
  1901. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k.npz', custom_load=True,
  1902. hf_hub_id='timm/',
  1903. input_size=(3, 240, 240), crop_pct=0.95),
  1904. 'flexivit_large.600ep_in1k': _cfg(
  1905. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_600ep.npz', custom_load=True,
  1906. hf_hub_id='timm/',
  1907. input_size=(3, 240, 240), crop_pct=0.95),
  1908. 'flexivit_large.300ep_in1k': _cfg(
  1909. url='https://storage.googleapis.com/big_vision/flexivit/flexivit_l_i1k_300ep.npz', custom_load=True,
  1910. hf_hub_id='timm/',
  1911. input_size=(3, 240, 240), crop_pct=0.95),
  1912. 'flexivit_base.patch16_in21k': _cfg(
  1913. url='https://storage.googleapis.com/big_vision/flexivit/vit_b16_i21k_300ep.npz', custom_load=True,
  1914. hf_hub_id='timm/',
  1915. input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
  1916. 'flexivit_base.patch30_in21k': _cfg(
  1917. url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
  1918. hf_hub_id='timm/',
  1919. input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
  1920. 'vit_base_patch16_xp_224.untrained': _cfg(url=''),
  1921. 'vit_large_patch14_xp_224.untrained': _cfg(url=''),
  1922. 'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
  1923. 'vit_base_patch16_224.mae': _cfg(
  1924. url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth',
  1925. hf_hub_id='timm/',
  1926. license='cc-by-nc-4.0',
  1927. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1928. 'vit_large_patch16_224.mae': _cfg(
  1929. url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_large.pth',
  1930. hf_hub_id='timm/',
  1931. license='cc-by-nc-4.0',
  1932. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1933. 'vit_huge_patch14_224.mae': _cfg(
  1934. url='https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_huge.pth',
  1935. hf_hub_id='timm/',
  1936. license='cc-by-nc-4.0',
  1937. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1938. 'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
  1939. url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
  1940. # hf_hub_id='timm/',
  1941. license='cc-by-nc-4.0',
  1942. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1943. 'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
  1944. url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
  1945. # hf_hub_id='timm/',
  1946. license='cc-by-nc-4.0',
  1947. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1948. 'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
  1949. url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
  1950. # hf_hub_id='timm/',
  1951. license='cc-by-nc-4.0',
  1952. input_size=(3, 448, 448), crop_pct=1.0,
  1953. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1954. 'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
  1955. url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
  1956. # hf_hub_id='timm/',
  1957. license='cc-by-nc-4.0',
  1958. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
  1959. 'vit_base_patch32_siglip_256.v2_webli': _cfg(
  1960. hf_hub_id='timm/',
  1961. input_size=(3, 256, 256),
  1962. num_classes=0),
  1963. 'vit_base_patch16_siglip_224.v2_webli': _cfg(
  1964. hf_hub_id='timm/',
  1965. num_classes=0),
  1966. 'vit_base_patch16_siglip_224.webli': _cfg(
  1967. hf_hub_id='timm/',
  1968. num_classes=0),
  1969. 'vit_base_patch16_siglip_256.v2_webli': _cfg(
  1970. hf_hub_id='timm/',
  1971. input_size=(3, 256, 256),
  1972. num_classes=0),
  1973. 'vit_base_patch16_siglip_256.webli': _cfg(
  1974. hf_hub_id='timm/',
  1975. input_size=(3, 256, 256),
  1976. num_classes=0),
  1977. 'vit_base_patch16_siglip_256.webli_i18n': _cfg(
  1978. hf_hub_id='timm/',
  1979. input_size=(3, 256, 256),
  1980. num_classes=0),
  1981. 'vit_base_patch16_siglip_384.v2_webli': _cfg(
  1982. hf_hub_id='timm/',
  1983. input_size=(3, 384, 384),
  1984. num_classes=0),
  1985. 'vit_base_patch16_siglip_384.webli': _cfg(
  1986. hf_hub_id='timm/',
  1987. input_size=(3, 384, 384),
  1988. num_classes=0),
  1989. 'vit_base_patch16_siglip_512.v2_webli': _cfg(
  1990. hf_hub_id='timm/',
  1991. input_size=(3, 512, 512),
  1992. num_classes=0),
  1993. 'vit_base_patch16_siglip_512.webli': _cfg(
  1994. hf_hub_id='timm/',
  1995. input_size=(3, 512, 512),
  1996. num_classes=0),
  1997. 'vit_large_patch16_siglip_256.v2_webli': _cfg(
  1998. hf_hub_id='timm/',
  1999. input_size=(3, 256, 256),
  2000. num_classes=0),
  2001. 'vit_large_patch16_siglip_256.webli': _cfg(
  2002. hf_hub_id='timm/',
  2003. input_size=(3, 256, 256),
  2004. num_classes=0),
  2005. 'vit_large_patch16_siglip_384.v2_webli': _cfg(
  2006. hf_hub_id='timm/',
  2007. input_size=(3, 384, 384),
  2008. num_classes=0),
  2009. 'vit_large_patch16_siglip_384.webli': _cfg(
  2010. hf_hub_id='timm/',
  2011. input_size=(3, 384, 384),
  2012. num_classes=0),
  2013. 'vit_large_patch16_siglip_512.v2_webli': _cfg(
  2014. hf_hub_id='timm/',
  2015. input_size=(3, 512, 512),
  2016. num_classes=0),
  2017. 'vit_so400m_patch14_siglip_224.v2_webli': _cfg(
  2018. hf_hub_id='timm/',
  2019. num_classes=0),
  2020. 'vit_so400m_patch14_siglip_224.webli': _cfg(
  2021. hf_hub_id='timm/',
  2022. num_classes=0),
  2023. 'vit_so400m_patch14_siglip_378.v2_webli': _cfg(
  2024. hf_hub_id='timm/',
  2025. input_size=(3, 378, 378),
  2026. num_classes=0),
  2027. 'vit_so400m_patch14_siglip_378.webli': _cfg(
  2028. hf_hub_id='timm/',
  2029. input_size=(3, 378, 378),
  2030. num_classes=0),
  2031. 'vit_so400m_patch14_siglip_384.webli': _cfg(
  2032. hf_hub_id='timm/',
  2033. input_size=(3, 384, 384),
  2034. num_classes=0),
  2035. 'vit_so400m_patch16_siglip_256.v2_webli': _cfg(
  2036. hf_hub_id='timm/',
  2037. input_size=(3, 256, 256),
  2038. num_classes=0),
  2039. 'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
  2040. hf_hub_id='timm/',
  2041. input_size=(3, 256, 256),
  2042. num_classes=0),
  2043. 'vit_so400m_patch16_siglip_384.v2_webli': _cfg(
  2044. hf_hub_id='timm/',
  2045. input_size=(3, 384, 384),
  2046. num_classes=0),
  2047. 'vit_so400m_patch16_siglip_512.v2_webli': _cfg(
  2048. hf_hub_id='timm/',
  2049. input_size=(3, 512, 512),
  2050. num_classes=0),
  2051. 'vit_giantopt_patch16_siglip_256.v2_webli': _cfg(
  2052. hf_hub_id='timm/',
  2053. input_size=(3, 256, 256),
  2054. num_classes=0),
  2055. 'vit_giantopt_patch16_siglip_384.v2_webli': _cfg(
  2056. hf_hub_id='timm/',
  2057. input_size=(3, 384, 384),
  2058. num_classes=0),
  2059. 'vit_base_patch32_siglip_gap_256.v2_webli': _cfg(
  2060. hf_hub_id='timm/',
  2061. input_size=(3, 256, 256),
  2062. num_classes=0),
  2063. 'vit_base_patch16_siglip_gap_224.v2_webli': _cfg(
  2064. hf_hub_id='timm/',
  2065. num_classes=0),
  2066. 'vit_base_patch16_siglip_gap_224.webli': _cfg(
  2067. hf_hub_id='timm/',
  2068. num_classes=0),
  2069. 'vit_base_patch16_siglip_gap_256.v2_webli': _cfg(
  2070. hf_hub_id='timm/',
  2071. input_size=(3, 256, 256),
  2072. num_classes=0),
  2073. 'vit_base_patch16_siglip_gap_256.webli': _cfg(
  2074. hf_hub_id='timm/',
  2075. input_size=(3, 256, 256),
  2076. num_classes=0),
  2077. 'vit_base_patch16_siglip_gap_256.webli_i18n': _cfg(
  2078. hf_hub_id='timm/',
  2079. input_size=(3, 256, 256),
  2080. num_classes=0),
  2081. 'vit_base_patch16_siglip_gap_384.v2_webli': _cfg(
  2082. hf_hub_id='timm/',
  2083. input_size=(3, 384, 384),
  2084. num_classes=0),
  2085. 'vit_base_patch16_siglip_gap_384.webli': _cfg(
  2086. hf_hub_id='timm/',
  2087. input_size=(3, 384, 384),
  2088. num_classes=0),
  2089. 'vit_base_patch16_siglip_gap_512.v2_webli': _cfg(
  2090. hf_hub_id='timm/',
  2091. input_size=(3, 512, 512),
  2092. num_classes=0),
  2093. 'vit_base_patch16_siglip_gap_512.webli': _cfg(
  2094. hf_hub_id='timm/',
  2095. input_size=(3, 512, 512),
  2096. num_classes=0),
  2097. 'vit_large_patch16_siglip_gap_256.v2_webli': _cfg(
  2098. hf_hub_id='timm/',
  2099. input_size=(3, 256, 256),
  2100. num_classes=0),
  2101. 'vit_large_patch16_siglip_gap_256.webli': _cfg(
  2102. hf_hub_id='timm/',
  2103. input_size=(3, 256, 256),
  2104. num_classes=0),
  2105. 'vit_large_patch16_siglip_gap_384.v2_webli': _cfg(
  2106. hf_hub_id='timm/',
  2107. input_size=(3, 384, 384),
  2108. num_classes=0),
  2109. 'vit_large_patch16_siglip_gap_384.webli': _cfg(
  2110. hf_hub_id='timm/',
  2111. input_size=(3, 384, 384),
  2112. num_classes=0),
  2113. 'vit_large_patch16_siglip_gap_512.v2_webli': _cfg(
  2114. hf_hub_id='timm/',
  2115. input_size=(3, 512, 512),
  2116. num_classes=0),
  2117. 'vit_so400m_patch14_siglip_gap_224.v2_webli': _cfg(
  2118. hf_hub_id='timm/',
  2119. num_classes=0),
  2120. 'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
  2121. hf_hub_id='timm/',
  2122. num_classes=0),
  2123. 'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
  2124. hf_hub_id='timm/',
  2125. num_classes=0),
  2126. 'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
  2127. hf_hub_id='timm/',
  2128. num_classes=0),
  2129. 'vit_so400m_patch14_siglip_gap_224.pali2_3b_pt': _cfg(
  2130. hf_hub_id='timm/',
  2131. num_classes=0),
  2132. 'vit_so400m_patch14_siglip_gap_224.pali2_10b_pt': _cfg(
  2133. hf_hub_id='timm/',
  2134. num_classes=0),
  2135. # 'vit_so400m_patch14_siglip_gap_224.pali2_28b_pt': _cfg(
  2136. # hf_hub_id='google/paligemma2-28b-pt-224-jax',
  2137. # hf_hub_filename='pt_27b_224.npz',
  2138. # custom_load='hf',
  2139. # num_classes=0),
  2140. 'vit_so400m_patch14_siglip_gap_378.v2_webli': _cfg(
  2141. hf_hub_id='timm/',
  2142. input_size=(3, 378, 378),
  2143. num_classes=0),
  2144. 'vit_so400m_patch14_siglip_gap_378.webli': _cfg(
  2145. hf_hub_id='timm/',
  2146. input_size=(3, 378, 378), crop_pct=1.0,
  2147. num_classes=0),
  2148. 'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
  2149. hf_hub_id='timm/',
  2150. input_size=(3, 384, 384), crop_pct=1.0,
  2151. num_classes=0),
  2152. 'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
  2153. hf_hub_id='timm/',
  2154. input_size=(3, 448, 448), crop_pct=1.0,
  2155. num_classes=0),
  2156. 'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
  2157. hf_hub_id='timm/',
  2158. input_size=(3, 448, 448), crop_pct=1.0,
  2159. num_classes=0),
  2160. 'vit_so400m_patch14_siglip_gap_448.pali_refcoco_seg': _cfg(
  2161. hf_hub_id='timm/',
  2162. input_size=(3, 448, 448), crop_pct=1.0,
  2163. num_classes=0),
  2164. 'vit_so400m_patch14_siglip_gap_448.pali_ocrvqa': _cfg(
  2165. hf_hub_id='timm/',
  2166. input_size=(3, 448, 448), crop_pct=1.0,
  2167. num_classes=0),
  2168. 'vit_so400m_patch14_siglip_gap_448.pali2_3b_pt': _cfg(
  2169. hf_hub_id='timm/',
  2170. input_size=(3, 448, 448), crop_pct=1.0,
  2171. num_classes=0),
  2172. 'vit_so400m_patch14_siglip_gap_448.pali2_10b_pt': _cfg(
  2173. hf_hub_id='timm/',
  2174. input_size=(3, 448, 448), crop_pct=1.0,
  2175. num_classes=0),
  2176. # 'vit_so400m_patch14_siglip_gap_448.pali2_28b_pt': _cfg(
  2177. # hf_hub_id='google/paligemma2-28b-pt-448-jax',
  2178. # hf_hub_filename='pt_27b_448.npz',
  2179. # custom_load='hf',
  2180. # input_size=(3, 448, 448), crop_pct=1.0,
  2181. # num_classes=0),
  2182. 'vit_so400m_patch14_siglip_gap_448.pali2_3b_docci': _cfg(
  2183. hf_hub_id='timm/',
  2184. input_size=(3, 448, 448), crop_pct=1.0,
  2185. num_classes=0),
  2186. 'vit_so400m_patch14_siglip_gap_448.pali2_10b_docci': _cfg(
  2187. hf_hub_id='timm/',
  2188. input_size=(3, 448, 448), crop_pct=1.0,
  2189. num_classes=0),
  2190. 'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
  2191. hf_hub_id='timm/',
  2192. input_size=(3, 896, 896), crop_pct=1.0,
  2193. num_classes=0),
  2194. 'vit_so400m_patch14_siglip_gap_896.pali_refcoco_seg': _cfg(
  2195. hf_hub_id='timm/',
  2196. input_size=(3, 896, 896), crop_pct=1.0,
  2197. num_classes=0),
  2198. 'vit_so400m_patch14_siglip_gap_896.pali_ocrvqa': _cfg(
  2199. hf_hub_id='timm/',
  2200. input_size=(3, 896, 896), crop_pct=1.0,
  2201. num_classes=0),
  2202. 'vit_so400m_patch14_siglip_gap_896.pali2_3b_pt': _cfg(
  2203. hf_hub_id='timm/',
  2204. input_size=(3, 896, 896), crop_pct=1.0,
  2205. num_classes=0),
  2206. 'vit_so400m_patch14_siglip_gap_896.pali2_10b_pt': _cfg(
  2207. hf_hub_id='timm/',
  2208. input_size=(3, 896, 896), crop_pct=1.0,
  2209. num_classes=0),
  2210. # 'vit_so400m_patch14_siglip_gap_896.pali2_28b_pt': _cfg(
  2211. # hf_hub_id='google/paligemma2-28b-pt-896-jax',
  2212. # hf_hub_filename='pt_27b_896.npz',
  2213. # custom_load='hf',
  2214. # input_size=(3, 896, 896), crop_pct=1.0,
  2215. # num_classes=0),
  2216. 'vit_so400m_patch16_siglip_gap_256.v2_webli': _cfg(
  2217. hf_hub_id='timm/',
  2218. input_size=(3, 256, 256),
  2219. num_classes=0),
  2220. 'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
  2221. hf_hub_id='timm/',
  2222. input_size=(3, 256, 256),
  2223. num_classes=0),
  2224. 'vit_so400m_patch16_siglip_gap_384.v2_webli': _cfg(
  2225. hf_hub_id='timm/',
  2226. input_size=(3, 384, 384),
  2227. num_classes=0),
  2228. 'vit_so400m_patch16_siglip_gap_512.v2_webli': _cfg(
  2229. hf_hub_id='timm/',
  2230. input_size=(3, 512, 512),
  2231. num_classes=0),
  2232. 'vit_giantopt_patch16_siglip_gap_256.v2_webli': _cfg(
  2233. hf_hub_id='timm/',
  2234. input_size=(3, 256, 256),
  2235. num_classes=0),
  2236. 'vit_giantopt_patch16_siglip_gap_384.v2_webli': _cfg(
  2237. hf_hub_id='timm/',
  2238. input_size=(3, 384, 384),
  2239. num_classes=0),
  2240. 'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg(
  2241. hf_hub_id='timm/',
  2242. input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
  2243. ),
  2244. 'vit_so400m_patch14_siglip_gap_378.webli_ft_in1k': _cfg(
  2245. hf_hub_id='timm/',
  2246. input_size=(3, 378, 378), crop_pct=1.0, crop_mode='squash',
  2247. ),
  2248. 'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
  2249. hf_hub_id='timm/',
  2250. license='mit',
  2251. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2252. 'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
  2253. hf_hub_id='timm/',
  2254. license='mit',
  2255. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2256. 'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
  2257. hf_hub_id='timm/',
  2258. license='mit',
  2259. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2260. 'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
  2261. hf_hub_id='timm/',
  2262. license='mit',
  2263. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
  2264. 'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2265. hf_hub_id='timm/',
  2266. input_size=(3, 256, 256), crop_pct=0.95),
  2267. 'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2268. hf_hub_id='timm/',
  2269. input_size=(3, 256, 256), crop_pct=0.95),
  2270. 'vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k': _cfg(
  2271. hf_hub_id='timm/',
  2272. input_size=(3, 256, 256), crop_pct=0.95),
  2273. 'vit_little_patch16_reg1_gap_256.sbb_in12k': _cfg(
  2274. hf_hub_id='timm/',
  2275. num_classes=11821,
  2276. input_size=(3, 256, 256), crop_pct=0.95),
  2277. 'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
  2278. hf_hub_id='timm/',
  2279. input_size=(3, 256, 256), crop_pct=0.95),
  2280. 'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2281. hf_hub_id='timm/',
  2282. input_size=(3, 256, 256), crop_pct=0.95),
  2283. 'vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
  2284. hf_hub_id='timm/',
  2285. input_size=(3, 256, 256), crop_pct=0.95),
  2286. 'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
  2287. hf_hub_id='timm/',
  2288. input_size=(3, 256, 256), crop_pct=0.95),
  2289. 'vit_medium_patch16_reg4_gap_256.sbb_in12k': _cfg(
  2290. hf_hub_id='timm/',
  2291. num_classes=11821,
  2292. input_size=(3, 256, 256), crop_pct=0.95),
  2293. 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg(
  2294. hf_hub_id='timm/',
  2295. input_size=(3, 256, 256), crop_pct=0.95),
  2296. 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
  2297. hf_hub_id='timm/',
  2298. input_size=(3, 256, 256), crop_pct=0.95),
  2299. 'vit_mediumd_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
  2300. hf_hub_id='timm/',
  2301. num_classes=11821,
  2302. input_size=(3, 256, 256), crop_pct=0.95),
  2303. 'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
  2304. hf_hub_id='timm/',
  2305. num_classes=11821,
  2306. input_size=(3, 256, 256), crop_pct=0.95),
  2307. 'vit_mediumd_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg(
  2308. hf_hub_id='timm/',
  2309. input_size=(3, 384, 384), crop_pct=1.0),
  2310. 'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
  2311. hf_hub_id='timm/',
  2312. input_size=(3, 256, 256), crop_pct=0.95),
  2313. 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k_ft_in1k': _cfg(
  2314. hf_hub_id='timm/',
  2315. input_size=(3, 256, 256), crop_pct=0.95),
  2316. 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
  2317. hf_hub_id='timm/',
  2318. input_size=(3, 256, 256), crop_pct=0.95),
  2319. 'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
  2320. hf_hub_id='timm/',
  2321. input_size=(3, 256, 256), crop_pct=0.95),
  2322. 'vit_betwixt_patch16_reg4_gap_256.sbb2_e200_in12k': _cfg(
  2323. hf_hub_id='timm/',
  2324. num_classes=11821,
  2325. input_size=(3, 256, 256), crop_pct=0.95),
  2326. 'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
  2327. hf_hub_id='timm/',
  2328. num_classes=11821,
  2329. input_size=(3, 256, 256), crop_pct=0.95),
  2330. 'vit_betwixt_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k': _cfg(
  2331. hf_hub_id='timm/',
  2332. input_size=(3, 384, 384), crop_pct=1.0),
  2333. 'vit_base_patch16_reg4_gap_256.untrained': _cfg(
  2334. input_size=(3, 256, 256)),
  2335. 'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k_ft_in1k': _cfg(
  2336. hf_hub_id='timm/',
  2337. input_size=(3, 256, 256), crop_pct=0.95),
  2338. 'vit_so150m_patch16_reg4_gap_256.sbb_e250_in12k': _cfg(
  2339. hf_hub_id='timm/',
  2340. num_classes=11821,
  2341. input_size=(3, 256, 256), crop_pct=0.95),
  2342. 'vit_so150m_patch16_reg4_gap_384.sbb_e250_in12k_ft_in1k': _cfg(
  2343. hf_hub_id='timm/',
  2344. input_size=(3, 384, 384), crop_pct=1.0),
  2345. 'vit_so150m_patch16_reg4_map_256.untrained': _cfg(
  2346. input_size=(3, 256, 256)),
  2347. 'vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k_ft_in1k': _cfg(
  2348. hf_hub_id='timm/',
  2349. input_size=(3, 256, 256), crop_pct=1.0),
  2350. 'vit_so150m2_patch16_reg1_gap_256.sbb_e200_in12k': _cfg(
  2351. hf_hub_id='timm/',
  2352. num_classes=11821,
  2353. input_size=(3, 256, 256), crop_pct=1.0),
  2354. 'vit_so150m2_patch16_reg1_gap_384.sbb_e200_in12k_ft_in1k': _cfg(
  2355. hf_hub_id='timm/',
  2356. input_size=(3, 384, 384), crop_pct=1.0),
  2357. 'vit_so150m2_patch16_reg1_gap_448.sbb_e200_in12k_ft_in1k': _cfg(
  2358. hf_hub_id='timm/',
  2359. input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash'),
  2360. 'vit_intern300m_patch14_448.ogvl_dist': _cfg(
  2361. hf_hub_id='timm/',
  2362. license='mit',
  2363. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  2364. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
  2365. ),
  2366. 'vit_intern300m_patch14_448.ogvl_2pt5': _cfg(
  2367. hf_hub_id='timm/',
  2368. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  2369. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0,
  2370. ),
  2371. 'aimv2_large_patch14_224.apple_pt': _cfg(
  2372. hf_hub_id='timm/',
  2373. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2374. crop_pct=1.0, num_classes=0),
  2375. 'aimv2_large_patch14_224.apple_pt_dist': _cfg(
  2376. hf_hub_id='timm/',
  2377. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2378. crop_pct=1.0, num_classes=0),
  2379. 'aimv2_huge_patch14_224.apple_pt': _cfg(
  2380. hf_hub_id='timm/',
  2381. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2382. crop_pct=1.0, num_classes=0),
  2383. 'aimv2_1b_patch14_224.apple_pt': _cfg(
  2384. hf_hub_id='timm/',
  2385. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2386. crop_pct=1.0, num_classes=0),
  2387. 'aimv2_3b_patch14_224.apple_pt': _cfg(
  2388. hf_hub_id='timm/',
  2389. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2390. crop_pct=1.0, num_classes=0),
  2391. 'aimv2_large_patch14_336.apple_pt': _cfg(
  2392. hf_hub_id='timm/',
  2393. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2394. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2395. 'aimv2_large_patch14_336.apple_pt_dist': _cfg(
  2396. hf_hub_id='timm/',
  2397. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2398. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2399. 'aimv2_huge_patch14_336.apple_pt': _cfg(
  2400. hf_hub_id='timm/',
  2401. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2402. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2403. 'aimv2_1b_patch14_336.apple_pt': _cfg(
  2404. hf_hub_id='timm/',
  2405. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2406. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2407. 'aimv2_3b_patch14_336.apple_pt': _cfg(
  2408. hf_hub_id='timm/',
  2409. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2410. input_size=(3, 336, 336), crop_pct=1.0, num_classes=0),
  2411. 'aimv2_large_patch14_448.apple_pt': _cfg(
  2412. hf_hub_id='timm/',
  2413. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2414. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
  2415. 'aimv2_huge_patch14_448.apple_pt': _cfg(
  2416. hf_hub_id='timm/',
  2417. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2418. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
  2419. 'aimv2_1b_patch14_448.apple_pt': _cfg(
  2420. hf_hub_id='timm/',
  2421. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2422. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
  2423. 'aimv2_3b_patch14_448.apple_pt': _cfg(
  2424. hf_hub_id='timm/',
  2425. mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, license='apple-ascl',
  2426. input_size=(3, 448, 448), crop_pct=1.0, num_classes=0),
  2427. 'test_vit.r160_in1k': _cfg(
  2428. hf_hub_id='timm/',
  2429. input_size=(3, 160, 160), crop_pct=0.95),
  2430. 'test_vit2.r160_in1k': _cfg(
  2431. hf_hub_id='timm/',
  2432. input_size=(3, 160, 160), crop_pct=0.95),
  2433. 'test_vit3.r160_in1k': _cfg(
  2434. hf_hub_id='timm/',
  2435. input_size=(3, 160, 160), crop_pct=0.95),
  2436. 'test_vit4.r160_in1k': _cfg(
  2437. input_size=(3, 160, 160), crop_pct=0.95),
  2438. # BEiT3 models (remapped to VisionTransformer with scale_attn_norm=True, scale_mlp_norm=True)
  2439. 'beit3_base_patch16_224.in22k_ft_in1k': _cfg(
  2440. hf_hub_id='timm/',
  2441. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2442. 'beit3_base_patch16_224.indomain_in22k_ft_in1k': _cfg(
  2443. hf_hub_id='timm/',
  2444. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2445. 'beit3_large_patch16_224.in22k_ft_in1k': _cfg(
  2446. hf_hub_id='timm/',
  2447. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2448. 'beit3_large_patch16_224.indomain_in22k_ft_in1k': _cfg(
  2449. hf_hub_id='timm/',
  2450. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2451. 'beit3_giant_patch14_224.untrained': _cfg(
  2452. url='', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2453. 'beit3_giant_patch14_336.untrained': _cfg(
  2454. url='', input_size=(3, 336, 336), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0),
  2455. 'beit3_base_patch16_224.pt': _cfg(
  2456. hf_hub_id='timm/',
  2457. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
  2458. num_classes=0,
  2459. ),
  2460. 'beit3_base_patch16_224.indomain_pt': _cfg(
  2461. hf_hub_id='timm/',
  2462. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
  2463. num_classes=0,
  2464. ),
  2465. 'beit3_large_patch16_224.pt': _cfg(
  2466. hf_hub_id='timm/',
  2467. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
  2468. num_classes=0,
  2469. ),
  2470. 'beit3_large_patch16_224.indomain_pt': _cfg(
  2471. hf_hub_id='timm/',
  2472. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=1.0,
  2473. num_classes=0,
  2474. ),
  2475. }
  2476. _quick_gelu_cfgs = [n for n, c in default_cfgs.items() if c.get('notes', ()) and 'quickgelu' in c['notes'][0]]
  2477. for n in _quick_gelu_cfgs:
  2478. # generate quickgelu default cfgs based on contents of notes field
  2479. c = copy.deepcopy(default_cfgs[n])
  2480. if c['hf_hub_id'] == 'timm/':
  2481. c['hf_hub_id'] = 'timm/' + n # need to use non-quickgelu model name for hub id
  2482. default_cfgs[n.replace('_clip_', '_clip_quickgelu_')] = c
  2483. default_cfgs = generate_default_cfgs(default_cfgs)
  2484. # Global flag to use NaFlexVit instead of VisionTransformer
  2485. _USE_NAFLEX_DEFAULT = os.environ.get('TIMM_USE_NAFLEXVIT', 'false').lower() == 'true'
  2486. def _create_vision_transformer(
  2487. variant: str,
  2488. pretrained: bool = False,
  2489. use_naflex: Optional[bool] = None,
  2490. **kwargs,
  2491. ) -> Union[VisionTransformer, 'NaFlexVit']:
  2492. # Check if we should use NaFlexVit instead
  2493. if use_naflex is None:
  2494. use_naflex = _USE_NAFLEX_DEFAULT
  2495. if use_naflex:
  2496. # Import here to avoid circular imports
  2497. from .naflexvit import _create_naflexvit_from_classic
  2498. return _create_naflexvit_from_classic(variant, pretrained, **kwargs)
  2499. out_indices = kwargs.pop('out_indices', 3)
  2500. if 'flexi' in variant:
  2501. # FIXME Google FlexiViT pretrained models have a strong preference for bilinear patch / embed
  2502. # interpolation, other pretrained models resize better w/ anti-aliased bicubic interpolation.
  2503. _filter_fn = partial(checkpoint_filter_fn, interpolation='bilinear', antialias=False)
  2504. else:
  2505. _filter_fn = checkpoint_filter_fn
  2506. # FIXME attn pool (currently only in siglip) params removed if pool disabled, is there a better soln?
  2507. strict = kwargs.pop('pretrained_strict', True)
  2508. if 'siglip' in variant and kwargs.get('global_pool', None) != 'map':
  2509. strict = False
  2510. return build_model_with_cfg(
  2511. VisionTransformer,
  2512. variant,
  2513. pretrained,
  2514. pretrained_filter_fn=_filter_fn,
  2515. pretrained_strict=strict,
  2516. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  2517. **kwargs,
  2518. )
  2519. @register_model
  2520. def vit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2521. """ ViT-Tiny (Vit-Ti/16)
  2522. """
  2523. model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
  2524. model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2525. return model
  2526. @register_model
  2527. def vit_tiny_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2528. """ ViT-Tiny (Vit-Ti/16) @ 384x384.
  2529. """
  2530. model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
  2531. model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2532. return model
  2533. @register_model
  2534. def vit_small_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2535. """ ViT-Small (ViT-S/32)
  2536. """
  2537. model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
  2538. model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2539. return model
  2540. @register_model
  2541. def vit_small_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2542. """ ViT-Small (ViT-S/32) at 384x384.
  2543. """
  2544. model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
  2545. model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2546. return model
  2547. @register_model
  2548. def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2549. """ ViT-Small (ViT-S/16)
  2550. """
  2551. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
  2552. model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2553. return model
  2554. @register_model
  2555. def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2556. """ ViT-Small (ViT-S/16)
  2557. """
  2558. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
  2559. model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2560. return model
  2561. @register_model
  2562. def vit_small_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2563. """ ViT-Small (ViT-S/8)
  2564. """
  2565. model_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6)
  2566. model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2567. return model
  2568. @register_model
  2569. def vit_base_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2570. """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
  2571. ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
  2572. """
  2573. model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
  2574. model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2575. return model
  2576. @register_model
  2577. def vit_base_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2578. """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
  2579. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
  2580. """
  2581. model_args = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
  2582. model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2583. return model
  2584. @register_model
  2585. def vit_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2586. """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
  2587. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
  2588. """
  2589. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  2590. model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2591. return model
  2592. @register_model
  2593. def vit_base_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2594. """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
  2595. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
  2596. """
  2597. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  2598. model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2599. return model
  2600. @register_model
  2601. def vit_base_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2602. """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
  2603. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
  2604. """
  2605. model_args = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12)
  2606. model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2607. return model
  2608. @register_model
  2609. def vit_large_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2610. """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
  2611. """
  2612. model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
  2613. model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2614. return model
  2615. @register_model
  2616. def vit_large_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2617. """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
  2618. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
  2619. """
  2620. model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
  2621. model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2622. return model
  2623. @register_model
  2624. def vit_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2625. """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
  2626. ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
  2627. """
  2628. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
  2629. model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2630. return model
  2631. @register_model
  2632. def vit_large_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2633. """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
  2634. ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
  2635. """
  2636. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
  2637. model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2638. return model
  2639. @register_model
  2640. def vit_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2641. """ ViT-Large model (ViT-L/14)
  2642. """
  2643. model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16)
  2644. model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2645. return model
  2646. @register_model
  2647. def vit_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2648. """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
  2649. """
  2650. model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16)
  2651. model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2652. return model
  2653. @register_model
  2654. def vit_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2655. """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  2656. """
  2657. model_args = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16)
  2658. model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2659. return model
  2660. @register_model
  2661. def vit_gigantic_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2662. """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  2663. """
  2664. model_args = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
  2665. model = _create_vision_transformer(
  2666. 'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2667. return model
  2668. @register_model
  2669. def vit_base_patch16_224_miil(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2670. """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
  2671. Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
  2672. """
  2673. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False)
  2674. model = _create_vision_transformer(
  2675. 'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_args, **kwargs))
  2676. return model
  2677. @register_model
  2678. def vit_medium_patch16_gap_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2679. """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240
  2680. """
  2681. model_args = dict(
  2682. patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
  2683. global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
  2684. model = _create_vision_transformer(
  2685. 'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_args, **kwargs))
  2686. return model
  2687. @register_model
  2688. def vit_medium_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2689. """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256
  2690. """
  2691. model_args = dict(
  2692. patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
  2693. global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
  2694. model = _create_vision_transformer(
  2695. 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  2696. return model
  2697. @register_model
  2698. def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2699. """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384
  2700. """
  2701. model_args = dict(
  2702. patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
  2703. global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
  2704. model = _create_vision_transformer(
  2705. 'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2706. return model
  2707. @register_model
  2708. def vit_betwixt_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2709. """ ViT-Betwixt (ViT-b/16) w/o class token, w/ avg-pool @ 256x256
  2710. """
  2711. model_args = dict(
  2712. patch_size=16, embed_dim=640, depth=12, num_heads=10, class_token=False,
  2713. global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
  2714. model = _create_vision_transformer(
  2715. 'vit_betwixt_patch16_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  2716. return model
  2717. @register_model
  2718. def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2719. """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
  2720. """
  2721. model_args = dict(
  2722. patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
  2723. model = _create_vision_transformer(
  2724. 'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2725. return model
  2726. @register_model
  2727. def vit_huge_patch14_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2728. """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool
  2729. """
  2730. model_args = dict(
  2731. patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
  2732. model = _create_vision_transformer(
  2733. 'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2734. return model
  2735. @register_model
  2736. def vit_huge_patch16_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2737. """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448
  2738. """
  2739. model_args = dict(
  2740. patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
  2741. model = _create_vision_transformer(
  2742. 'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2743. return model
  2744. @register_model
  2745. def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2746. """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool
  2747. """
  2748. model_args = dict(
  2749. patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
  2750. class_token=False, global_pool='avg', fc_norm=False)
  2751. model = _create_vision_transformer(
  2752. 'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2753. return model
  2754. @register_model
  2755. def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2756. # TinyCLIP 8M
  2757. model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2758. model = _create_vision_transformer(
  2759. 'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2760. return model
  2761. @register_model
  2762. def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2763. # TinyCLIP 40M
  2764. model_args = dict(
  2765. patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2766. model = _create_vision_transformer(
  2767. 'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2768. return model
  2769. @register_model
  2770. def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2771. # TinyCLIP 39M
  2772. model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2773. model = _create_vision_transformer(
  2774. 'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2775. return model
  2776. @register_model
  2777. def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2778. # TinyCLIP 61M
  2779. model_args = dict(
  2780. patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2781. model = _create_vision_transformer(
  2782. 'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2783. return model
  2784. @register_model
  2785. def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2786. """ ViT-B/32 CLIP image tower @ 224x224
  2787. """
  2788. model_args = dict(
  2789. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2790. model = _create_vision_transformer(
  2791. 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2792. return model
  2793. @register_model
  2794. def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2795. """ ViT-B/32 CLIP image tower @ 256x256
  2796. """
  2797. model_args = dict(
  2798. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2799. model = _create_vision_transformer(
  2800. 'vit_base_patch32_clip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  2801. return model
  2802. @register_model
  2803. def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2804. """ ViT-B/32 CLIP image tower @ 384x384
  2805. """
  2806. model_args = dict(
  2807. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2808. model = _create_vision_transformer(
  2809. 'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2810. return model
  2811. @register_model
  2812. def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2813. """ ViT-B/32 CLIP image tower @ 448x448
  2814. """
  2815. model_args = dict(
  2816. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2817. model = _create_vision_transformer(
  2818. 'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2819. return model
  2820. @register_model
  2821. def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2822. """ ViT-B/16 CLIP image tower
  2823. """
  2824. model_args = dict(
  2825. patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2826. model = _create_vision_transformer(
  2827. 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2828. return model
  2829. @register_model
  2830. def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2831. """ ViT-B/16 CLIP image tower @ 384x384
  2832. """
  2833. model_args = dict(
  2834. patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2835. model = _create_vision_transformer(
  2836. 'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  2837. return model
  2838. @register_model
  2839. def vit_base_patch16_plus_clip_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2840. """ ViT-Base (ViT-B/16+) CLIP image tower @ 240x240
  2841. """
  2842. model_args = dict(
  2843. patch_size=16, embed_dim=896, depth=12, num_heads=14, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2844. model = _create_vision_transformer(
  2845. 'vit_base_patch16_plus_clip_240', pretrained=pretrained, **dict(model_args, **kwargs))
  2846. return model
  2847. @register_model
  2848. def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2849. """ ViT-Large model (ViT-L/14) CLIP image tower
  2850. """
  2851. model_args = dict(
  2852. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2853. model = _create_vision_transformer(
  2854. 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2855. return model
  2856. @register_model
  2857. def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2858. """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336
  2859. """
  2860. model_args = dict(
  2861. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2862. model = _create_vision_transformer(
  2863. 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
  2864. return model
  2865. @register_model
  2866. def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2867. """ ViT-Huge model (ViT-H/14) CLIP image tower.
  2868. """
  2869. model_args = dict(
  2870. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2871. model = _create_vision_transformer(
  2872. 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2873. return model
  2874. @register_model
  2875. def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2876. """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336
  2877. """
  2878. model_args = dict(
  2879. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2880. model = _create_vision_transformer(
  2881. 'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
  2882. return model
  2883. @register_model
  2884. def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2885. """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378
  2886. """
  2887. model_args = dict(
  2888. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=partial(LayerNorm, eps=1e-5))
  2889. model = _create_vision_transformer(
  2890. 'vit_huge_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
  2891. return model
  2892. @register_model
  2893. def vit_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2894. """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  2895. Pretrained weights from CLIP image tower.
  2896. """
  2897. model_args = dict(
  2898. patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True,
  2899. norm_layer=partial(LayerNorm, eps=1e-5),
  2900. )
  2901. model = _create_vision_transformer(
  2902. 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2903. return model
  2904. @register_model
  2905. def vit_gigantic_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2906. """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  2907. Pretrained weights from CLIP image tower.
  2908. """
  2909. model_args = dict(
  2910. patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True,
  2911. norm_layer=partial(LayerNorm, eps=1e-5),
  2912. )
  2913. model = _create_vision_transformer(
  2914. 'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2915. return model
  2916. @register_model
  2917. def vit_gigantic_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2918. """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
  2919. Pretrained weights from CLIP image tower.
  2920. """
  2921. model_args = dict(
  2922. patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True,
  2923. norm_layer=partial(LayerNorm, eps=1e-5),
  2924. )
  2925. model = _create_vision_transformer(
  2926. 'vit_gigantic_patch14_clip_378', pretrained=pretrained, **dict(model_args, **kwargs))
  2927. return model
  2928. @register_model
  2929. def vit_base_patch32_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2930. """ ViT-B/32 CLIP image tower @ 224x224
  2931. """
  2932. model_args = dict(
  2933. patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
  2934. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  2935. )
  2936. model = _create_vision_transformer(
  2937. 'vit_base_patch32_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2938. return model
  2939. @register_model
  2940. def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2941. """ ViT-B/16 CLIP image tower w/ QuickGELU act
  2942. """
  2943. model_args = dict(
  2944. patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True,
  2945. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  2946. )
  2947. model = _create_vision_transformer(
  2948. 'vit_base_patch16_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2949. return model
  2950. @register_model
  2951. def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2952. """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act
  2953. """
  2954. model_args = dict(
  2955. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
  2956. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  2957. )
  2958. model = _create_vision_transformer(
  2959. 'vit_large_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2960. return model
  2961. @register_model
  2962. def vit_large_patch14_clip_quickgelu_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2963. """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act
  2964. """
  2965. model_args = dict(
  2966. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True,
  2967. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  2968. )
  2969. model = _create_vision_transformer(
  2970. 'vit_large_patch14_clip_quickgelu_336', pretrained=pretrained, **dict(model_args, **kwargs))
  2971. return model
  2972. @register_model
  2973. def vit_huge_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2974. """ ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act.
  2975. """
  2976. model_args = dict(
  2977. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
  2978. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  2979. )
  2980. model = _create_vision_transformer(
  2981. 'vit_huge_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2982. return model
  2983. @register_model
  2984. def vit_huge_patch14_clip_quickgelu_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2985. """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act
  2986. """
  2987. model_args = dict(
  2988. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True,
  2989. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  2990. )
  2991. model = _create_vision_transformer(
  2992. 'vit_huge_patch14_clip_quickgelu_378', pretrained=pretrained, **dict(model_args, **kwargs))
  2993. return model
  2994. @register_model
  2995. def vit_gigantic_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  2996. """ ViT-bigG model (ViT-G/14) w/ QuickGELU act
  2997. """
  2998. model_args = dict(
  2999. patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True,
  3000. norm_layer=partial(LayerNorm, eps=1e-5), act_layer='quick_gelu'
  3001. )
  3002. model = _create_vision_transformer(
  3003. 'vit_gigantic_patch14_clip_quickgelu_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3004. return model
  3005. # Experimental models below
  3006. @register_model
  3007. def vit_base_patch32_plus_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3008. """ ViT-Base (ViT-B/32+)
  3009. """
  3010. model_args = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
  3011. model = _create_vision_transformer(
  3012. 'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3013. return model
  3014. @register_model
  3015. def vit_base_patch16_plus_240(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3016. """ ViT-Base (ViT-B/16+)
  3017. """
  3018. model_args = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
  3019. model = _create_vision_transformer(
  3020. 'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_args, **kwargs))
  3021. return model
  3022. @register_model
  3023. def vit_base_patch16_rpn_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3024. """ ViT-Base (ViT-B/16) w/ residual post-norm
  3025. """
  3026. model_args = dict(
  3027. patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
  3028. class_token=False, block_fn=ResPostBlock, global_pool='avg')
  3029. model = _create_vision_transformer(
  3030. 'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3031. return model
  3032. @register_model
  3033. def vit_small_patch16_36x1_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3034. """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
  3035. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
  3036. Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
  3037. """
  3038. model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5)
  3039. model = _create_vision_transformer(
  3040. 'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3041. return model
  3042. @register_model
  3043. def vit_small_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3044. """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
  3045. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
  3046. Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
  3047. """
  3048. model_args = dict(
  3049. patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock)
  3050. model = _create_vision_transformer(
  3051. 'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3052. return model
  3053. @register_model
  3054. def vit_base_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3055. """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
  3056. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
  3057. """
  3058. model_args = dict(
  3059. patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock)
  3060. model = _create_vision_transformer(
  3061. 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3062. return model
  3063. @register_model
  3064. def eva_large_patch14_196(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3065. """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain"""
  3066. model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
  3067. model = _create_vision_transformer(
  3068. 'eva_large_patch14_196', pretrained=pretrained, **dict(model_args, **kwargs))
  3069. return model
  3070. @register_model
  3071. def eva_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3072. """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
  3073. model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
  3074. model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3075. return model
  3076. @register_model
  3077. def flexivit_small(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3078. """ FlexiViT-Small
  3079. """
  3080. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True)
  3081. model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_args, **kwargs))
  3082. return model
  3083. @register_model
  3084. def flexivit_base(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3085. """ FlexiViT-Base
  3086. """
  3087. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
  3088. model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_args, **kwargs))
  3089. return model
  3090. @register_model
  3091. def flexivit_large(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3092. """ FlexiViT-Large
  3093. """
  3094. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
  3095. model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_args, **kwargs))
  3096. return model
  3097. @register_model
  3098. def vit_base_patch16_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3099. """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
  3100. """
  3101. model_args = dict(
  3102. patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True,
  3103. norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
  3104. )
  3105. model = _create_vision_transformer(
  3106. 'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3107. return model
  3108. @register_model
  3109. def vit_large_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3110. """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
  3111. """
  3112. model_args = dict(
  3113. patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True,
  3114. norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
  3115. )
  3116. model = _create_vision_transformer(
  3117. 'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3118. return model
  3119. @register_model
  3120. def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3121. """ ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.
  3122. """
  3123. model_args = dict(
  3124. patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True,
  3125. norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
  3126. )
  3127. model = _create_vision_transformer(
  3128. 'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3129. return model
  3130. @register_model
  3131. def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3132. """ ViT-S/14 for DINOv2
  3133. """
  3134. model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5)
  3135. model = _create_vision_transformer(
  3136. 'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3137. return model
  3138. @register_model
  3139. def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3140. """ ViT-B/14 for DINOv2
  3141. """
  3142. model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5)
  3143. model = _create_vision_transformer(
  3144. 'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3145. return model
  3146. @register_model
  3147. def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3148. """ ViT-L/14 for DINOv2
  3149. """
  3150. model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5)
  3151. model = _create_vision_transformer(
  3152. 'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3153. return model
  3154. @register_model
  3155. def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3156. """ ViT-G/14 for DINOv2
  3157. """
  3158. # The hidden_features of SwiGLU is calculated by:
  3159. # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  3160. # When embed_dim=1536, hidden_features=4096
  3161. # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
  3162. model_args = dict(
  3163. patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
  3164. mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, act_layer=nn.SiLU
  3165. )
  3166. model = _create_vision_transformer(
  3167. 'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3168. return model
  3169. @register_model
  3170. def vit_small_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3171. """ ViT-S/14 for DINOv2 w/ 4 registers
  3172. """
  3173. model_args = dict(
  3174. patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5,
  3175. reg_tokens=4, no_embed_class=True,
  3176. )
  3177. model = _create_vision_transformer(
  3178. 'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3179. return model
  3180. @register_model
  3181. def vit_base_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3182. """ ViT-B/14 for DINOv2 w/ 4 registers
  3183. """
  3184. model_args = dict(
  3185. patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
  3186. reg_tokens=4, no_embed_class=True,
  3187. )
  3188. model = _create_vision_transformer(
  3189. 'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3190. return model
  3191. @register_model
  3192. def vit_large_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3193. """ ViT-L/14 for DINOv2 w/ 4 registers
  3194. """
  3195. model_args = dict(
  3196. patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
  3197. reg_tokens=4, no_embed_class=True,
  3198. )
  3199. model = _create_vision_transformer(
  3200. 'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3201. return model
  3202. @register_model
  3203. def vit_giant_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3204. """ ViT-G/14 for DINOv2
  3205. """
  3206. # The hidden_features of SwiGLU is calculated by:
  3207. # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  3208. # When embed_dim=1536, hidden_features=4096
  3209. # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
  3210. model_args = dict(
  3211. patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2,
  3212. mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True,
  3213. )
  3214. model = _create_vision_transformer(
  3215. 'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
  3216. return model
  3217. @register_model
  3218. def vit_base_patch32_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3219. model_args = dict(
  3220. patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3221. act_layer='gelu_tanh',
  3222. )
  3223. model = _create_vision_transformer(
  3224. 'vit_base_patch32_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3225. return model
  3226. @register_model
  3227. def vit_base_patch16_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3228. model_args = dict(
  3229. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3230. )
  3231. model = _create_vision_transformer(
  3232. 'vit_base_patch16_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3233. return model
  3234. @register_model
  3235. def vit_base_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3236. model_args = dict(
  3237. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3238. )
  3239. model = _create_vision_transformer(
  3240. 'vit_base_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3241. return model
  3242. @register_model
  3243. def vit_base_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3244. model_args = dict(
  3245. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3246. )
  3247. model = _create_vision_transformer(
  3248. 'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3249. return model
  3250. @register_model
  3251. def vit_base_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3252. model_args = dict(
  3253. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
  3254. )
  3255. model = _create_vision_transformer(
  3256. 'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3257. return model
  3258. @register_model
  3259. def vit_large_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3260. model_args = dict(
  3261. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
  3262. )
  3263. model = _create_vision_transformer(
  3264. 'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3265. return model
  3266. @register_model
  3267. def vit_large_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3268. model_args = dict(
  3269. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
  3270. )
  3271. model = _create_vision_transformer(
  3272. 'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3273. return model
  3274. @register_model
  3275. def vit_large_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3276. model_args = dict(
  3277. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
  3278. act_layer='gelu_tanh'
  3279. )
  3280. model = _create_vision_transformer(
  3281. 'vit_large_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3282. return model
  3283. @register_model
  3284. def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3285. model_args = dict(
  3286. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3287. )
  3288. model = _create_vision_transformer(
  3289. 'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3290. return model
  3291. @register_model
  3292. def vit_so400m_patch14_siglip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3293. # this is a corrected variant of the 384 with a res properly divisible by patch size (no padding/truncation)
  3294. model_args = dict(
  3295. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3296. )
  3297. model = _create_vision_transformer(
  3298. 'vit_so400m_patch14_siglip_378', pretrained=pretrained, **dict(model_args, **kwargs))
  3299. return model
  3300. @register_model
  3301. def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3302. model_args = dict(
  3303. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3304. )
  3305. model = _create_vision_transformer(
  3306. 'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3307. return model
  3308. @register_model
  3309. def vit_so400m_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3310. model_args = dict(
  3311. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3312. act_layer='gelu_tanh',
  3313. )
  3314. model = _create_vision_transformer(
  3315. 'vit_so400m_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3316. return model
  3317. @register_model
  3318. def vit_so400m_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3319. model_args = dict(
  3320. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3321. act_layer='gelu_tanh',
  3322. )
  3323. model = _create_vision_transformer(
  3324. 'vit_so400m_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3325. return model
  3326. @register_model
  3327. def vit_so400m_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3328. model_args = dict(
  3329. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
  3330. act_layer='gelu_tanh',
  3331. )
  3332. model = _create_vision_transformer(
  3333. 'vit_so400m_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3334. return model
  3335. @register_model
  3336. def vit_giantopt_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3337. model_args = dict(
  3338. patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map',
  3339. act_layer='gelu_tanh',
  3340. )
  3341. model = _create_vision_transformer(
  3342. 'vit_giantopt_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3343. return model
  3344. @register_model
  3345. def vit_giantopt_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3346. model_args = dict(
  3347. patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map',
  3348. act_layer='gelu_tanh',
  3349. )
  3350. model = _create_vision_transformer(
  3351. 'vit_giantopt_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3352. return model
  3353. @register_model
  3354. def vit_base_patch32_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3355. model_args = dict(
  3356. patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3357. act_layer='gelu_tanh',
  3358. )
  3359. model = _create_vision_transformer(
  3360. 'vit_base_patch32_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3361. return model
  3362. @register_model
  3363. def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3364. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3365. model_args = dict(
  3366. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3367. )
  3368. model = _create_vision_transformer(
  3369. 'vit_base_patch16_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3370. return model
  3371. @register_model
  3372. def vit_base_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3373. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3374. model_args = dict(
  3375. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3376. )
  3377. model = _create_vision_transformer(
  3378. 'vit_base_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3379. return model
  3380. @register_model
  3381. def vit_base_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3382. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3383. model_args = dict(
  3384. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3385. )
  3386. model = _create_vision_transformer(
  3387. 'vit_base_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3388. return model
  3389. @register_model
  3390. def vit_base_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3391. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3392. model_args = dict(
  3393. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
  3394. )
  3395. model = _create_vision_transformer(
  3396. 'vit_base_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3397. return model
  3398. @register_model
  3399. def vit_large_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3400. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3401. model_args = dict(
  3402. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
  3403. )
  3404. model = _create_vision_transformer(
  3405. 'vit_large_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3406. return model
  3407. @register_model
  3408. def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3409. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3410. model_args = dict(
  3411. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
  3412. )
  3413. model = _create_vision_transformer(
  3414. 'vit_large_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3415. return model
  3416. @register_model
  3417. def vit_large_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3418. model_args = dict(
  3419. patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False,
  3420. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3421. )
  3422. model = _create_vision_transformer(
  3423. 'vit_large_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3424. return model
  3425. @register_model
  3426. def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3427. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3428. model_args = dict(
  3429. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3430. class_token=False, global_pool='avg', fc_norm=False,
  3431. )
  3432. model = _create_vision_transformer(
  3433. 'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3434. return model
  3435. @register_model
  3436. def vit_so400m_patch14_siglip_gap_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3437. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3438. model_args = dict(
  3439. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3440. class_token=False, global_pool='avg', fc_norm=False,
  3441. )
  3442. model = _create_vision_transformer(
  3443. 'vit_so400m_patch14_siglip_gap_378', pretrained=pretrained, **dict(model_args, **kwargs))
  3444. return model
  3445. @register_model
  3446. def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3447. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3448. model_args = dict(
  3449. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3450. class_token=False, global_pool='avg', fc_norm=False,
  3451. )
  3452. model = _create_vision_transformer(
  3453. 'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3454. return model
  3455. @register_model
  3456. def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3457. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3458. model_args = dict(
  3459. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3460. class_token=False, global_pool='avg', fc_norm=False,
  3461. )
  3462. model = _create_vision_transformer(
  3463. 'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3464. return model
  3465. @register_model
  3466. def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3467. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3468. model_args = dict(
  3469. patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3470. class_token=False, global_pool='avg', fc_norm=False,
  3471. )
  3472. model = _create_vision_transformer(
  3473. 'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
  3474. return model
  3475. @register_model
  3476. def vit_so400m_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3477. """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
  3478. model_args = dict(
  3479. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
  3480. class_token=False, global_pool='avg', fc_norm=False, act_layer='gelu_tanh',
  3481. )
  3482. model = _create_vision_transformer(
  3483. 'vit_so400m_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3484. return model
  3485. @register_model
  3486. def vit_so400m_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3487. model_args = dict(
  3488. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
  3489. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3490. )
  3491. model = _create_vision_transformer(
  3492. 'vit_so400m_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3493. return model
  3494. @register_model
  3495. def vit_so400m_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3496. model_args = dict(
  3497. patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
  3498. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3499. )
  3500. model = _create_vision_transformer(
  3501. 'vit_so400m_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
  3502. return model
  3503. @register_model
  3504. def vit_giantopt_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3505. model_args = dict(
  3506. patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False,
  3507. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3508. )
  3509. model = _create_vision_transformer(
  3510. 'vit_giantopt_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3511. return model
  3512. @register_model
  3513. def vit_giantopt_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3514. model_args = dict(
  3515. patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False,
  3516. global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
  3517. )
  3518. model = _create_vision_transformer(
  3519. 'vit_giantopt_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3520. return model
  3521. @register_model
  3522. def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3523. model_args = dict(
  3524. patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
  3525. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
  3526. )
  3527. model = _create_vision_transformer(
  3528. 'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3529. return model
  3530. @register_model
  3531. def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3532. model_args = dict(
  3533. patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
  3534. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
  3535. )
  3536. model = _create_vision_transformer(
  3537. 'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3538. return model
  3539. @register_model
  3540. def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3541. model_args = dict(
  3542. patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
  3543. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
  3544. )
  3545. model = _create_vision_transformer(
  3546. 'vit_little_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3547. return model
  3548. @register_model
  3549. def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3550. model_args = dict(
  3551. patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
  3552. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3553. )
  3554. model = _create_vision_transformer(
  3555. 'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3556. return model
  3557. @register_model
  3558. def vit_medium_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3559. model_args = dict(
  3560. patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
  3561. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
  3562. )
  3563. model = _create_vision_transformer(
  3564. 'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3565. return model
  3566. @register_model
  3567. def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3568. model_args = dict(
  3569. patch_size=16, embed_dim=512, depth=12, num_heads=8, init_values=1e-5,
  3570. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3571. )
  3572. model = _create_vision_transformer(
  3573. 'vit_medium_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3574. return model
  3575. @register_model
  3576. def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3577. model_args = dict(
  3578. patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
  3579. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3580. )
  3581. model = _create_vision_transformer(
  3582. 'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3583. return model
  3584. @register_model
  3585. def vit_mediumd_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3586. model_args = dict(
  3587. patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
  3588. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3589. )
  3590. model = _create_vision_transformer(
  3591. 'vit_mediumd_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3592. return model
  3593. @register_model
  3594. def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3595. model_args = dict(
  3596. patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
  3597. class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
  3598. )
  3599. model = _create_vision_transformer(
  3600. 'vit_betwixt_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3601. return model
  3602. @register_model
  3603. def vit_betwixt_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3604. model_args = dict(
  3605. patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
  3606. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3607. )
  3608. model = _create_vision_transformer(
  3609. 'vit_betwixt_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3610. return model
  3611. @register_model
  3612. def vit_betwixt_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3613. model_args = dict(
  3614. patch_size=16, embed_dim=640, depth=12, num_heads=10, init_values=1e-5,
  3615. class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
  3616. )
  3617. model = _create_vision_transformer(
  3618. 'vit_betwixt_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3619. return model
  3620. @register_model
  3621. def vit_base_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3622. model_args = dict(
  3623. patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False,
  3624. no_embed_class=True, global_pool='avg', reg_tokens=4,
  3625. )
  3626. model = _create_vision_transformer(
  3627. 'vit_base_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3628. return model
  3629. @register_model
  3630. def vit_so150m_patch16_reg4_map_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3631. """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """
  3632. model_args = dict(
  3633. patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
  3634. class_token=False, reg_tokens=4, global_pool='map',
  3635. )
  3636. model = _create_vision_transformer(
  3637. 'vit_so150m_patch16_reg4_map_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3638. return model
  3639. @register_model
  3640. def vit_so150m_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3641. """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """
  3642. model_args = dict(
  3643. patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
  3644. class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
  3645. )
  3646. model = _create_vision_transformer(
  3647. 'vit_so150m_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3648. return model
  3649. @register_model
  3650. def vit_so150m_patch16_reg4_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3651. """ SO150M (shape optimized, but diff than paper def, optimized for GPU) """
  3652. model_args = dict(
  3653. patch_size=16, embed_dim=896, depth=18, num_heads=14, mlp_ratio=2.572,
  3654. class_token=False, reg_tokens=4, global_pool='avg', fc_norm=False,
  3655. )
  3656. model = _create_vision_transformer(
  3657. 'vit_so150m_patch16_reg4_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3658. return model
  3659. @register_model
  3660. def vit_so150m2_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3661. """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """
  3662. model_args = dict(
  3663. patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
  3664. qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg',
  3665. )
  3666. model = _create_vision_transformer(
  3667. 'vit_so150m2_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  3668. return model
  3669. @register_model
  3670. def vit_so150m2_patch16_reg1_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3671. """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """
  3672. model_args = dict(
  3673. patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
  3674. qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg',
  3675. )
  3676. model = _create_vision_transformer(
  3677. 'vit_so150m2_patch16_reg1_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
  3678. return model
  3679. @register_model
  3680. def vit_so150m2_patch16_reg1_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3681. """ SO150M v2 (shape optimized, but diff than paper def, optimized for GPU) """
  3682. model_args = dict(
  3683. patch_size=16, embed_dim=832, depth=21, num_heads=13, mlp_ratio=34/13, init_values=1e-5,
  3684. qkv_bias=False, class_token=False, reg_tokens=1, global_pool='avg',
  3685. )
  3686. model = _create_vision_transformer(
  3687. 'vit_so150m2_patch16_reg1_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3688. return model
  3689. @register_model
  3690. def vit_intern300m_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3691. model_args = dict(
  3692. patch_size=14, embed_dim=1024, depth=24, num_heads=16,
  3693. init_values=0.1, final_norm=False, dynamic_img_size=True,
  3694. )
  3695. model = _create_vision_transformer(
  3696. 'vit_intern300m_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3697. return model
  3698. @register_model
  3699. def aimv2_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3700. """ ViT Large AIM-v2 model
  3701. """
  3702. model_args = dict(
  3703. patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
  3704. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3705. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3706. )
  3707. model = _create_vision_transformer(
  3708. 'aimv2_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3709. return model
  3710. @register_model
  3711. def aimv2_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3712. """ ViT Huge AIM-v2 model
  3713. """
  3714. model_args = dict(
  3715. patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
  3716. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3717. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3718. )
  3719. model = _create_vision_transformer(
  3720. 'aimv2_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3721. return model
  3722. @register_model
  3723. def aimv2_1b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3724. """ ViT 1B AIM-v2 model
  3725. """
  3726. model_args = dict(
  3727. patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
  3728. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3729. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3730. )
  3731. model = _create_vision_transformer(
  3732. 'aimv2_1b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3733. return model
  3734. @register_model
  3735. def aimv2_3b_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3736. """ ViT 3B AIM-v2 model
  3737. """
  3738. model_args = dict(
  3739. patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
  3740. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3741. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3742. )
  3743. model = _create_vision_transformer(
  3744. 'aimv2_3b_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3745. return model
  3746. @register_model
  3747. def aimv2_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3748. """ ViT Large AIM-v2 model
  3749. """
  3750. model_args = dict(
  3751. patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
  3752. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3753. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3754. )
  3755. model = _create_vision_transformer(
  3756. 'aimv2_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3757. return model
  3758. @register_model
  3759. def aimv2_huge_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3760. """ ViT Huge AIM-v2 model
  3761. """
  3762. model_args = dict(
  3763. patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
  3764. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3765. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3766. )
  3767. model = _create_vision_transformer(
  3768. 'aimv2_huge_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3769. return model
  3770. @register_model
  3771. def aimv2_1b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3772. """ ViT 1B AIM-v2 model
  3773. """
  3774. model_args = dict(
  3775. patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
  3776. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3777. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3778. )
  3779. model = _create_vision_transformer(
  3780. 'aimv2_1b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3781. return model
  3782. @register_model
  3783. def aimv2_3b_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3784. """ ViT 3B AIM-v2 model
  3785. """
  3786. model_args = dict(
  3787. patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
  3788. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3789. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3790. )
  3791. model = _create_vision_transformer(
  3792. 'aimv2_3b_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3793. return model
  3794. @register_model
  3795. def aimv2_large_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3796. """ ViT Large AIM-v2 model
  3797. """
  3798. model_args = dict(
  3799. patch_size=14, embed_dim=1024, depth=24, num_heads=8, class_token=False, fc_norm=False,
  3800. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3801. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3802. )
  3803. model = _create_vision_transformer(
  3804. 'aimv2_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3805. return model
  3806. @register_model
  3807. def aimv2_huge_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3808. """ ViT Huge AIM-v2 model
  3809. """
  3810. model_args = dict(
  3811. patch_size=14, embed_dim=1536, depth=24, num_heads=12, class_token=False, fc_norm=False,
  3812. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3813. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3814. )
  3815. model = _create_vision_transformer(
  3816. 'aimv2_huge_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3817. return model
  3818. @register_model
  3819. def aimv2_1b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3820. """ ViT 1B AIM-v2 model
  3821. """
  3822. model_args = dict(
  3823. patch_size=14, embed_dim=2048, depth=24, num_heads=16, class_token=False, fc_norm=False,
  3824. mlp_ratio=2.75, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3825. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3826. )
  3827. model = _create_vision_transformer(
  3828. 'aimv2_1b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3829. return model
  3830. @register_model
  3831. def aimv2_3b_patch14_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3832. """ ViT 3B AIM-v2 model
  3833. """
  3834. model_args = dict(
  3835. patch_size=14, embed_dim=3072, depth=24, num_heads=24, class_token=False, fc_norm=False,
  3836. mlp_ratio=2.6667, global_pool='avg', qkv_bias=False, proj_bias=False, act_layer='silu',
  3837. norm_layer=partial(RmsNorm, eps=1e-5), embed_norm_layer=partial(RmsNorm, eps=1e-5), mlp_layer=SwiGLU,
  3838. )
  3839. model = _create_vision_transformer(
  3840. 'aimv2_3b_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  3841. return model
  3842. @register_model
  3843. def test_vit(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3844. """ ViT Test
  3845. """
  3846. model_args = dict(patch_size=16, embed_dim=64, depth=6, num_heads=2, mlp_ratio=3, dynamic_img_size=True)
  3847. model = _create_vision_transformer('test_vit', pretrained=pretrained, **dict(model_args, **kwargs))
  3848. return model
  3849. @register_model
  3850. def test_vit2(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3851. """ ViT Test
  3852. """
  3853. model_args = dict(
  3854. patch_size=16, embed_dim=64, depth=8, num_heads=2, mlp_ratio=3,
  3855. class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True)
  3856. model = _create_vision_transformer('test_vit2', pretrained=pretrained, **dict(model_args, **kwargs))
  3857. return model
  3858. @register_model
  3859. def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3860. """ ViT Test
  3861. """
  3862. model_args = dict(
  3863. patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=2,
  3864. class_token=False, reg_tokens=1, global_pool='map', pool_include_prefix=True, init_values=1e-5)
  3865. model = _create_vision_transformer('test_vit3', pretrained=pretrained, **dict(model_args, **kwargs))
  3866. return model
  3867. @register_model
  3868. def test_vit4(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3869. """ ViT Test
  3870. """
  3871. model_args = dict(
  3872. patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=3,
  3873. class_token=False, reg_tokens=1, global_pool='avg', init_values=1e-5, dynamic_img_size=True,
  3874. norm_layer='rmsnorm',
  3875. )
  3876. model = _create_vision_transformer('test_vit4', pretrained=pretrained, **dict(model_args, **kwargs))
  3877. return model
  3878. @register_model
  3879. def beit3_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3880. """ BEiT3 Base model (ViT-Base size) with patch size 16x16.
  3881. Remapped to VisionTransformer with scale_norm=True.
  3882. """
  3883. model_args = dict(
  3884. patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
  3885. scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
  3886. norm_layer=partial(LayerNorm, eps=1e-5)
  3887. )
  3888. model = _create_vision_transformer('beit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3889. return model
  3890. @register_model
  3891. def beit3_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3892. """ BEiT3 Large model (ViT-Large size) with patch size 16x16.
  3893. Remapped to VisionTransformer with scale_norm=True.
  3894. """
  3895. model_args = dict(
  3896. patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
  3897. scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
  3898. norm_layer=partial(LayerNorm, eps=1e-5),
  3899. )
  3900. model = _create_vision_transformer('beit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3901. return model
  3902. @register_model
  3903. def beit3_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3904. """ BEiT3 Giant model with patch size 14x14.
  3905. Remapped to VisionTransformer with scale_norm=True.
  3906. """
  3907. model_args = dict(
  3908. patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
  3909. scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
  3910. norm_layer=partial(LayerNorm, eps=1e-5),
  3911. )
  3912. model = _create_vision_transformer('beit3_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  3913. return model
  3914. @register_model
  3915. def beit3_giant_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer:
  3916. """ BEiT3 Giant model with patch size 14x14 and image size 336x336.
  3917. Remapped to VisionTransformer with scale_norm=True.
  3918. """
  3919. model_args = dict(
  3920. img_size=336, patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=4.3637,
  3921. scale_attn_norm=True, scale_mlp_norm=True, class_token=True, global_pool='avg',
  3922. norm_layer=partial(LayerNorm, eps=1e-5),
  3923. )
  3924. model = _create_vision_transformer('beit3_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  3925. return model
  3926. register_model_deprecations(__name__, {
  3927. 'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
  3928. 'vit_small_patch32_224_in21k': 'vit_small_patch32_224.augreg_in21k',
  3929. 'vit_small_patch16_224_in21k': 'vit_small_patch16_224.augreg_in21k',
  3930. 'vit_base_patch32_224_in21k': 'vit_base_patch32_224.augreg_in21k',
  3931. 'vit_base_patch16_224_in21k': 'vit_base_patch16_224.augreg_in21k',
  3932. 'vit_base_patch8_224_in21k': 'vit_base_patch8_224.augreg_in21k',
  3933. 'vit_large_patch32_224_in21k': 'vit_large_patch32_224.orig_in21k',
  3934. 'vit_large_patch16_224_in21k': 'vit_large_patch16_224.augreg_in21k',
  3935. 'vit_huge_patch14_224_in21k': 'vit_huge_patch14_224.orig_in21k',
  3936. 'vit_base_patch32_224_sam': 'vit_base_patch32_224.sam',
  3937. 'vit_base_patch16_224_sam': 'vit_base_patch16_224.sam',
  3938. 'vit_small_patch16_224_dino': 'vit_small_patch16_224.dino',
  3939. 'vit_small_patch8_224_dino': 'vit_small_patch8_224.dino',
  3940. 'vit_base_patch16_224_dino': 'vit_base_patch16_224.dino',
  3941. 'vit_base_patch8_224_dino': 'vit_base_patch8_224.dino',
  3942. 'vit_base_patch16_224_miil_in21k': 'vit_base_patch16_224_miil.in21k',
  3943. 'vit_base_patch32_224_clip_laion2b': 'vit_base_patch32_clip_224.laion2b',
  3944. 'vit_large_patch14_224_clip_laion2b': 'vit_large_patch14_clip_224.laion2b',
  3945. 'vit_huge_patch14_224_clip_laion2b': 'vit_huge_patch14_clip_224.laion2b',
  3946. 'vit_giant_patch14_224_clip_laion2b': 'vit_giant_patch14_clip_224.laion2b',
  3947. })