eva.py 106 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936
  1. """ EVA
  2. EVA ViT from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
  3. This file contains a number of ViT variants that utilise ROPE position embeddings, SwiGLU and other additions:
  4. * EVA & EVA02 model implementations that evolved from BEiT, additional models in vision_transformer.py.
  5. * `timm` original SBB ViT w/ ROPE position embeddings
  6. * Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)
  7. * ROPE-ViT from Naver AI (https://arxiv.org/abs/2403.13298)
  8. * DINOv3 from Meta AI Research (https://arxiv.org/abs/2508.10104)
  9. @article{EVA,
  10. title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
  11. author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
  12. Tiejun and Wang, Xinlong and Cao, Yue},
  13. journal={arXiv preprint arXiv:2211.07636},
  14. year={2022}
  15. }
  16. EVA-02: A Visual Representation for Neon Genesis - https://arxiv.org/abs/2303.11331
  17. @article{EVA02,
  18. title={EVA-02: A Visual Representation for Neon Genesis},
  19. author={Fang, Yuxin and Sun, Quan and Wang, Xinggang and Huang, Tiejun and Wang, Xinlong and Cao, Yue},
  20. journal={arXiv preprint arXiv:2303.11331},
  21. year={2023}
  22. }
  23. @article{bolya2025perception,
  24. title={Perception encoder: The best visual embeddings are not at the output of the network},
  25. author={Bolya, Daniel and Huang, Po-Yao and Sun, Peize and Cho, Jang Hyun and Madotto, Andrea and Wei, Chen and Ma,
  26. Tengyu and Zhi, Jiale and Rajasegaran, Jathushan and Rasheed, Hanoona and others},
  27. journal={arXiv preprint arXiv:2504.13181},
  28. year={2025}
  29. }
  30. @inproceedings{heo2024rotary,
  31. title={Rotary position embedding for vision transformer},
  32. author={Heo, Byeongho and Park, Song and Han, Dongyoon and Yun, Sangdoo},
  33. booktitle={European Conference on Computer Vision},
  34. pages={289--305},
  35. year={2024},
  36. organization={Springer}
  37. }
  38. @article{simeoni2025dinov3,
  39. title={{DINOv3}},
  40. author={Sim{\'e}oni, Oriane and Vo, Huy V. and Seitzer, Maximilian and Baldassarre, Federico and Oquab, Maxime
  41. and Jose, Cijo and Khalidov, Vasil and Szafraniec, Marc and Yi, Seungeun and Ramamonjisoa, Micha{\"e}l
  42. and Massa, Francisco and Haziza, Daniel and Wehrstedt, Luca and Wang, Jianyuan and Darcet, Timoth{\'e}e
  43. and Moutakanni, Th{\'e}o and Sentana, Leonel and Roberts, Claire and Vedaldi, Andrea and Tolan, Jamie
  44. and Brandt, John and Couprie, Camille and Mairal, Julien and J{\'e}gou, Herv{\'e} and Labatut, Patrick
  45. and Bojanowski, Piotr},
  46. year={2025},
  47. eprint={2508.10104},
  48. url={https://arxiv.org/abs/2508.10104},
  49. }
  50. DINOv3 code was a modification of the existing EVA model and support modules, so it is licensed under Apache-2.0 like timm.
  51. Weights from META remain under DINOv3 License (https://ai.meta.com/resources/models-and-libraries/dinov3-license/).
  52. Modifications by / Copyright 2023 Ross Wightman, original copyrights below
  53. """
  54. # EVA models Copyright (c) 2022 BAAI-Vision
  55. # EVA02 models Copyright (c) 2023 BAAI-Vision
  56. import math
  57. import os
  58. from functools import partial
  59. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
  60. import torch
  61. import torch.nn as nn
  62. import torch.nn.functional as F
  63. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
  64. from timm.layers import (
  65. PatchEmbed,
  66. Mlp,
  67. GluMlp,
  68. SwiGLU,
  69. LayerNorm,
  70. DropPath, calculate_drop_path_rates,
  71. PatchDropoutWithIndices,
  72. create_rope_embed,
  73. apply_rot_embed_cat,
  74. apply_keep_indices_nlc,
  75. trunc_normal_,
  76. resample_patch_embed,
  77. resample_abs_pos_embed,
  78. global_pool_nlc,
  79. to_2tuple,
  80. use_fused_attn,
  81. maybe_add_mask,
  82. AttentionRope,
  83. AttentionPoolLatent,
  84. )
  85. from ._builder import build_model_with_cfg
  86. from ._features import feature_take_indices
  87. from ._manipulate import checkpoint
  88. from ._registry import generate_default_cfgs, register_model
  89. __all__ = ['Eva']
  90. class EvaAttention(nn.Module):
  91. """ EVA Attention with ROPE, no k-bias, and fused/unfused qkv options
  92. """
  93. fused_attn: torch.jit.Final[bool]
  94. def __init__(
  95. self,
  96. dim: int,
  97. num_heads: int = 8,
  98. qkv_bias: bool = True,
  99. qkv_fused: bool = True,
  100. qkv_bias_separate: bool = False,
  101. num_prefix_tokens: int = 1,
  102. attn_drop: float = 0.,
  103. proj_drop: float = 0.,
  104. attn_head_dim: Optional[int] = None,
  105. norm_layer: Optional[Callable] = None,
  106. qk_norm: bool = False,
  107. scale_norm: bool = True,
  108. rotate_half: bool = False,
  109. device=None,
  110. dtype=None,
  111. ):
  112. """
  113. Args:
  114. dim: Input dimension of the token embeddings
  115. num_heads: Number of attention heads
  116. qkv_bias: Whether to add a bias term to the query, key, and value projections
  117. qkv_fused: Whether qkv projections are fused into one projection or separate
  118. qkv_bias_separate: Whether to apply bias to qkv as a separate addition or part of F.linear() call
  119. num_prefix_tokens: Number of reg/cls tokens at the beginning of the sequence that
  120. should not have position embeddings applied
  121. attn_drop: Dropout rate for attention weights
  122. proj_drop: Dropout rate for the output projection
  123. attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
  124. norm_layer: Normalization layer constructor to use for QK and scale normalization
  125. qk_norm: Enable normalization of query (Q) and key (K) vectors with norm_layer
  126. scale_norm: Enable normalization (scaling) of attention output with norm_layer
  127. rotate_half: Use half rotation layout instead of interleaved
  128. """
  129. dd = {'device': device, 'dtype': dtype}
  130. super().__init__()
  131. if scale_norm or qk_norm:
  132. assert norm_layer is not None, 'norm_layer must be provided if qk_norm or scale_norm is True'
  133. self.num_heads = num_heads
  134. head_dim = dim // num_heads
  135. if attn_head_dim is not None:
  136. head_dim = attn_head_dim
  137. attn_dim = head_dim * self.num_heads
  138. self.scale = head_dim ** -0.5
  139. self.num_prefix_tokens = num_prefix_tokens
  140. self.fused_attn = use_fused_attn()
  141. self.qkv_bias_separate = qkv_bias_separate
  142. self.rotate_half = rotate_half
  143. if qkv_fused:
  144. self.qkv = nn.Linear(dim, attn_dim * 3, bias=False, **dd)
  145. self.q_proj = self.k_proj = self.v_proj = None
  146. if qkv_bias:
  147. self.q_bias = nn.Parameter(torch.zeros(attn_dim, **dd))
  148. self.register_buffer('k_bias', torch.zeros(attn_dim, **dd), persistent=False)
  149. self.v_bias = nn.Parameter(torch.zeros(attn_dim, **dd))
  150. else:
  151. self.q_bias = self.k_bias = self.v_bias = None
  152. else:
  153. self.q_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
  154. self.k_proj = nn.Linear(dim, attn_dim, bias=False, **dd)
  155. self.v_proj = nn.Linear(dim, attn_dim, bias=qkv_bias, **dd)
  156. self.qkv = None
  157. self.q_bias = self.k_bias = self.v_bias = None
  158. self.q_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  159. self.k_norm = norm_layer(self.head_dim, **dd) if qk_norm else nn.Identity()
  160. self.attn_drop = nn.Dropout(attn_drop)
  161. self.norm = norm_layer(attn_dim, **dd) if scale_norm else nn.Identity()
  162. self.proj = nn.Linear(attn_dim, dim, **dd)
  163. self.proj_drop = nn.Dropout(proj_drop)
  164. def forward(
  165. self,
  166. x,
  167. rope: Optional[torch.Tensor] = None,
  168. attn_mask: Optional[torch.Tensor] = None,
  169. ):
  170. """Forward pass for the attention module.
  171. Args:
  172. x: Input tensor of shape (batch_size, sequence_length, embedding_dim)
  173. rope: Rotary position embeddings tensor for position-aware attention
  174. attn_mask: Optional attention mask to apply during attention computation
  175. Returns:
  176. Tensor of shape (batch_size, sequence_length, embedding_dim)
  177. """
  178. B, N, C = x.shape
  179. if self.qkv is not None:
  180. if self.q_bias is None:
  181. qkv = self.qkv(x)
  182. else:
  183. qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
  184. if self.qkv_bias_separate:
  185. qkv = self.qkv(x)
  186. qkv += qkv_bias
  187. else:
  188. qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
  189. qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  190. q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
  191. else:
  192. q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
  193. k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
  194. v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
  195. q, k = self.q_norm(q), self.k_norm(k)
  196. if rope is not None:
  197. npt = self.num_prefix_tokens
  198. half = getattr(self, 'rotate_half', False)
  199. q = torch.cat([q[:, :, :npt, :], apply_rot_embed_cat(q[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
  200. k = torch.cat([k[:, :, :npt, :], apply_rot_embed_cat(k[:, :, npt:, :], rope, half=half)], dim=2).type_as(v)
  201. if self.fused_attn:
  202. x = F.scaled_dot_product_attention(
  203. q, k, v,
  204. attn_mask=attn_mask,
  205. dropout_p=self.attn_drop.p if self.training else 0.,
  206. )
  207. else:
  208. q = q * self.scale
  209. attn = (q @ k.transpose(-2, -1))
  210. attn = maybe_add_mask(attn, attn_mask)
  211. attn = attn.softmax(dim=-1)
  212. attn = self.attn_drop(attn)
  213. x = attn @ v
  214. x = x.transpose(1, 2).reshape(B, N, C)
  215. x = self.norm(x)
  216. x = self.proj(x)
  217. x = self.proj_drop(x)
  218. return x
  219. class EvaBlock(nn.Module):
  220. def __init__(
  221. self,
  222. dim: int,
  223. num_heads: int,
  224. qkv_bias: bool = True,
  225. qkv_fused: bool = True,
  226. mlp_ratio: float = 4.,
  227. swiglu_mlp: bool = False,
  228. swiglu_align_to: int = 0,
  229. scale_mlp: bool = False,
  230. scale_attn_inner: bool = False,
  231. num_prefix_tokens: int = 1,
  232. attn_type: str = 'eva',
  233. rotate_half: bool = False,
  234. proj_drop: float = 0.,
  235. attn_drop: float = 0.,
  236. drop_path: float = 0.,
  237. init_values: Optional[float] = None,
  238. act_layer: Callable = nn.GELU,
  239. norm_layer: Callable = LayerNorm,
  240. attn_head_dim: Optional[int] = None,
  241. device=None,
  242. dtype=None,
  243. **kwargs,
  244. ):
  245. """ Initialize the EVA transformer block.
  246. Args:
  247. dim: Input dimension of the token embeddings
  248. num_heads: Number of attention heads
  249. qkv_bias: Whether to use bias terms in query, key, value projections
  250. qkv_fused: Whether to use a single projection for query, key, value
  251. mlp_ratio: Ratio of MLP hidden dimension to input dimension
  252. swiglu_mlp: Whether to use SwiGLU activation in the MLP
  253. scale_mlp: Whether to use normalization in the MLP
  254. scale_attn_inner: Whether to use normalization within the attention mechanism
  255. num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.)
  256. attn_type: Type of attention module to use ('eva' or 'rope')
  257. proj_drop: Dropout rate for projection layers
  258. attn_drop: Dropout rate for attention matrix
  259. drop_path: Stochastic depth rate
  260. init_values: Initial value for LayerScale, None = no LayerScale
  261. act_layer: Activation layer constructor
  262. norm_layer: Normalization layer constructor
  263. attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
  264. """
  265. dd = {'device': device, 'dtype': dtype}
  266. super().__init__()
  267. self.norm1 = norm_layer(dim, **dd)
  268. attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
  269. self.attn = attn_cls(
  270. dim,
  271. num_heads=num_heads,
  272. qkv_bias=qkv_bias,
  273. qkv_fused=qkv_fused,
  274. num_prefix_tokens=num_prefix_tokens,
  275. attn_drop=attn_drop,
  276. proj_drop=proj_drop,
  277. attn_head_dim=attn_head_dim,
  278. norm_layer=norm_layer,
  279. scale_norm=scale_attn_inner,
  280. rotate_half=rotate_half,
  281. **dd,
  282. )
  283. self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None
  284. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  285. self.norm2 = norm_layer(dim, **dd)
  286. hidden_features = int(dim * mlp_ratio)
  287. if swiglu_mlp:
  288. if scale_mlp or swiglu_align_to:
  289. # when norm in SwiGLU used or alignment enabled, an impl with separate fc for gate & x is used
  290. self.mlp = SwiGLU(
  291. in_features=dim,
  292. hidden_features=hidden_features,
  293. norm_layer=norm_layer if scale_mlp else None,
  294. drop=proj_drop,
  295. align_to=swiglu_align_to,
  296. **dd,
  297. )
  298. else:
  299. # w/o any extra norm, an impl with packed weights is used
  300. self.mlp = GluMlp(
  301. in_features=dim,
  302. hidden_features=hidden_features * 2,
  303. norm_layer=norm_layer if scale_mlp else None,
  304. act_layer=nn.SiLU,
  305. gate_last=False,
  306. drop=proj_drop,
  307. **dd,
  308. )
  309. else:
  310. self.mlp = Mlp(
  311. in_features=dim,
  312. hidden_features=hidden_features,
  313. act_layer=act_layer,
  314. norm_layer=norm_layer if scale_mlp else None,
  315. drop=proj_drop,
  316. **dd,
  317. )
  318. self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd)) if init_values is not None else None
  319. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  320. def forward(
  321. self,
  322. x: torch.Tensor,
  323. rope: Optional[torch.Tensor] = None,
  324. attn_mask: Optional[torch.Tensor] = None,
  325. ) -> torch.Tensor:
  326. if self.gamma_1 is None:
  327. x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
  328. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  329. else:
  330. x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
  331. x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
  332. return x
  333. class EvaBlockPostNorm(nn.Module):
  334. """ EVA block w/ post-norm and support for swiglu, MLP norm scale, ROPE. """
  335. def __init__(
  336. self,
  337. dim: int,
  338. num_heads: int,
  339. qkv_bias: bool = True,
  340. qkv_fused: bool = True,
  341. mlp_ratio: float = 4.,
  342. attn_type: str = 'eva',
  343. rotate_half: bool = False,
  344. swiglu_mlp: bool = False,
  345. swiglu_align_to: int = 0,
  346. scale_mlp: bool = False,
  347. scale_attn_inner: bool = False,
  348. num_prefix_tokens: int = 1,
  349. proj_drop: float = 0.,
  350. attn_drop: float = 0.,
  351. drop_path: float = 0.,
  352. init_values: Optional[float] = None, # ignore for post-norm
  353. act_layer: Callable = nn.GELU,
  354. norm_layer: Callable = nn.LayerNorm,
  355. attn_head_dim: Optional[int] = None,
  356. device=None,
  357. dtype=None,
  358. ):
  359. """ Initialize the post-norm EVA transformer block.
  360. Args:
  361. dim: Input dimension of the token embeddings
  362. num_heads: Number of attention heads
  363. qkv_bias: Whether to use bias terms in query, key, value projections
  364. qkv_fused: Whether to use a single projection for query, key, value
  365. mlp_ratio: Ratio of MLP hidden dimension to input dimension
  366. swiglu_mlp: Whether to use SwiGLU activation in the MLP
  367. scale_mlp: Whether to use normalization in the MLP
  368. scale_attn_inner: Whether to use normalization within the attention mechanism
  369. num_prefix_tokens: Number of tokens at the beginning of the sequence (class tokens, etc.)
  370. attn_type: Type of attention module to use ('eva' or 'rope')
  371. proj_drop: Dropout rate for projection layers
  372. attn_drop: Dropout rate for attention matrix
  373. drop_path: Stochastic depth rate
  374. init_values: Initial value for LayerScale, None = no LayerScale (NOTE: ignored for post-norm block)
  375. act_layer: Activation layer constructor
  376. norm_layer: Normalization layer constructor
  377. attn_head_dim: Dimension of each attention head (if None, computed as dim // num_heads)
  378. """
  379. dd = {'device': device, 'dtype': dtype}
  380. super().__init__()
  381. attn_cls = AttentionRope if attn_type == 'rope' else EvaAttention
  382. self.attn = attn_cls(
  383. dim,
  384. num_heads=num_heads,
  385. qkv_bias=qkv_bias,
  386. qkv_fused=qkv_fused,
  387. num_prefix_tokens=num_prefix_tokens,
  388. attn_drop=attn_drop,
  389. proj_drop=proj_drop,
  390. attn_head_dim=attn_head_dim,
  391. norm_layer=norm_layer,
  392. scale_norm=scale_attn_inner,
  393. rotate_half=rotate_half,
  394. **dd,
  395. )
  396. self.norm1 = norm_layer(dim, **dd)
  397. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  398. hidden_features = int(dim * mlp_ratio)
  399. if swiglu_mlp:
  400. if scale_mlp:
  401. # when norm in SwiGLU used, an impl with separate fc for gate & x is used
  402. self.mlp = SwiGLU(
  403. in_features=dim,
  404. hidden_features=hidden_features,
  405. norm_layer=norm_layer if scale_mlp else None,
  406. drop=proj_drop,
  407. align_to=swiglu_align_to,
  408. **dd,
  409. )
  410. else:
  411. # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
  412. self.mlp = GluMlp(
  413. in_features=dim,
  414. hidden_features=hidden_features * 2,
  415. norm_layer=norm_layer if scale_mlp else None,
  416. act_layer=nn.SiLU,
  417. gate_last=False,
  418. drop=proj_drop,
  419. **dd,
  420. )
  421. else:
  422. self.mlp = Mlp(
  423. in_features=dim,
  424. hidden_features=hidden_features,
  425. act_layer=act_layer,
  426. norm_layer=norm_layer if scale_mlp else None,
  427. drop=proj_drop,
  428. **dd,
  429. )
  430. self.norm2 = norm_layer(dim, **dd)
  431. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  432. def forward(
  433. self,
  434. x: torch.Tensor,
  435. rope: Optional[torch.Tensor] = None,
  436. attn_mask: Optional[torch.Tensor] = None,
  437. ) -> torch.Tensor:
  438. x = x + self.drop_path1(self.norm1(self.attn(x, rope=rope, attn_mask=attn_mask)))
  439. x = x + self.drop_path2(self.norm2(self.mlp(x)))
  440. return x
  441. class Eva(nn.Module):
  442. """ Eva Vision Transformer w/ Abs & Rotary Pos Embed
  443. This class implements the EVA and EVA02 models that were based on the BEiT ViT variant
  444. * EVA - abs pos embed, global avg pool
  445. * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer)
  446. """
  447. def __init__(
  448. self,
  449. img_size: Union[int, Tuple[int, int]] = 224,
  450. patch_size: Union[int, Tuple[int, int]] = 16,
  451. in_chans: int = 3,
  452. num_classes: int = 1000,
  453. global_pool: str = 'avg',
  454. embed_dim: int = 768,
  455. depth: int = 12,
  456. num_heads: int = 12,
  457. qkv_bias: bool = True,
  458. qkv_fused: bool = True,
  459. mlp_ratio: float = 4.,
  460. swiglu_mlp: bool = False,
  461. swiglu_align_to: int = 0,
  462. scale_mlp: bool = False,
  463. scale_attn_inner: bool = False,
  464. attn_type: str = 'eva',
  465. drop_rate: float = 0.,
  466. pos_drop_rate: float = 0.,
  467. patch_drop_rate: float = 0.,
  468. proj_drop_rate: float = 0.,
  469. attn_drop_rate: float = 0.,
  470. drop_path_rate: float = 0.,
  471. norm_layer: Callable = LayerNorm,
  472. init_values: Optional[float] = None,
  473. class_token: bool = True,
  474. num_reg_tokens: int = 0,
  475. no_embed_class: bool = False,
  476. use_abs_pos_emb: bool = True,
  477. use_rot_pos_emb: bool = False,
  478. rope_type: Optional[str] = 'cat',
  479. rope_grid_offset: float = 0.,
  480. rope_grid_indexing: str = 'ij',
  481. rope_temperature: float = 10000.,
  482. rope_rotate_half: bool = False,
  483. use_post_norm: bool = False,
  484. use_pre_transformer_norm: bool = False,
  485. use_post_transformer_norm: Optional[bool] = None,
  486. use_fc_norm: Optional[bool] = None,
  487. attn_pool_num_heads: Optional[int] = None,
  488. attn_pool_mlp_ratio: Optional[float] = None,
  489. dynamic_img_size: bool = False,
  490. dynamic_img_pad: bool = False,
  491. ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
  492. head_init_scale: float = 0.001,
  493. device=None,
  494. dtype=None,
  495. ):
  496. """Initialize the EVA Vision Transformer model.
  497. Args:
  498. img_size: Input image size (single int for square, or tuple for rectangular)
  499. patch_size: Patch size to divide image into tokens (single int for square, or tuple)
  500. in_chans: Number of input image channels
  501. num_classes: Number of classes (output dim) for classification head (final projection), 0 for pass-through
  502. global_pool: Type of global pooling for final sequence ('avg', 'token', 'map', etc.)
  503. embed_dim: Embedding dimension for tokens
  504. depth: Number of transformer blocks
  505. num_heads: Number of attention heads
  506. qkv_bias: Enable bias for query, key, value projections
  507. qkv_fused: Use a single projection for query, key, value
  508. mlp_ratio: Ratio of mlp hidden dim to embedding dim
  509. swiglu_mlp: Use SwiGLU activation in MLP
  510. scale_mlp: Apply scaling normalization in MLP (normformer style)
  511. scale_attn_inner: Apply scaling normalization inside attention
  512. attn_type: Type of attention module to use
  513. drop_rate: Dropout rate after final projection and pooling
  514. pos_drop_rate: Dropout rate for positional embeddings
  515. patch_drop_rate: Rate of dropping patches during training
  516. proj_drop_rate: Dropout rate for projections
  517. attn_drop_rate: Dropout rate for attention
  518. drop_path_rate: Stochastic depth rate
  519. norm_layer: Normalization layer constructor
  520. init_values: Initial layer-scale values
  521. class_token: Use class token
  522. num_reg_tokens: Number of additional learnable 'register' tokens to add to the sequence
  523. no_embed_class: Don't include position embeddings for class (or reg) tokens
  524. use_abs_pos_emb: Use absolute (learned) positional embeddings
  525. use_rot_pos_emb: Use rotary position embeddings
  526. rope_type: Type of RoPE to use ('cat', 'mixed', 'dinov3', etc.).
  527. rope_grid_offset: Offset for rotary position embedding grid
  528. rope_grid_indexing: Indexing mode for rotary position embeddings ('ij' or 'xy')
  529. rope_temperature: Temperature parameter for ROPE frequency computation
  530. rope_rotate_half: Use half rotation layout (rotate D/2 dims), else use interleaved rotation layout
  531. use_post_norm: Use post-norm transformer block type
  532. use_pre_transformer_norm: Use normalization layer before transformer blocks
  533. use_post_transformer_norm: Use normalization layer after transformer blocks
  534. use_fc_norm: Use normalization layer after pooling, before final classifier
  535. attn_pool_num_heads: Number of heads in attention pooling
  536. attn_pool_mlp_ratio: MLP ratio in attention pooling
  537. dynamic_img_size: Support dynamic image sizes in forward pass
  538. dynamic_img_pad: Apply dynamic padding for irregular image sizes
  539. ref_feat_shape: Reference feature shape for rotary position embedding scale
  540. head_init_scale: Initialization scale for classification head weights
  541. """
  542. super().__init__()
  543. dd = {'device': device, 'dtype': dtype}
  544. assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
  545. self.num_classes = num_classes
  546. self.global_pool = global_pool
  547. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  548. self.num_prefix_tokens = (1 if class_token else 0) + num_reg_tokens
  549. self.no_embed_class = no_embed_class
  550. self.dynamic_img_size = dynamic_img_size
  551. self.grad_checkpointing = False
  552. # resolve norm / pool usage
  553. activate_pre_norm = use_pre_transformer_norm
  554. if use_fc_norm is not None:
  555. activate_fc_norm = use_fc_norm # pass through if explicit
  556. else:
  557. activate_fc_norm = global_pool == 'avg' # default on if avg pool used
  558. if use_post_transformer_norm is not None:
  559. activate_post_norm = use_post_transformer_norm # pass through if explicit
  560. else:
  561. activate_post_norm = not activate_fc_norm # default on if fc_norm isn't active
  562. embed_args = {}
  563. if dynamic_img_size:
  564. # flatten deferred until after pos embed
  565. embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
  566. self.patch_embed = PatchEmbed(
  567. img_size=img_size,
  568. patch_size=patch_size,
  569. in_chans=in_chans,
  570. embed_dim=embed_dim,
  571. dynamic_img_pad=dynamic_img_pad,
  572. bias=not use_pre_transformer_norm,
  573. **embed_args,
  574. **dd,
  575. )
  576. num_patches = self.patch_embed.num_patches
  577. r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
  578. self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, **dd)) if class_token else None
  579. self.reg_token = nn.Parameter(torch.empty(1, num_reg_tokens, embed_dim, **dd)) if num_reg_tokens else None
  580. self.cls_embed = class_token and self.reg_token is None
  581. num_pos_tokens = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
  582. self.pos_embed = nn.Parameter(torch.empty(1, num_pos_tokens, embed_dim, **dd)) if use_abs_pos_emb else None
  583. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  584. if patch_drop_rate > 0:
  585. self.patch_drop = PatchDropoutWithIndices(patch_drop_rate, num_prefix_tokens=self.num_prefix_tokens)
  586. else:
  587. self.patch_drop = None
  588. self.rope_mixed = False
  589. if use_rot_pos_emb:
  590. ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
  591. # Setup RoPE kwargs
  592. rope_kwargs = dict(
  593. dim=embed_dim,
  594. num_heads=num_heads,
  595. feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
  596. temperature=rope_temperature,
  597. grid_indexing=rope_grid_indexing,
  598. **dd,
  599. )
  600. if rope_type == 'mixed':
  601. rope_kwargs.update(dict(depth=depth))
  602. self.rope_mixed = True
  603. elif rope_type == 'cat':
  604. rope_kwargs.update(dict(
  605. in_pixels=False,
  606. grid_offset=rope_grid_offset,
  607. ref_feat_shape=ref_feat_shape,
  608. ))
  609. self.rope = create_rope_embed(rope_type=rope_type, **rope_kwargs)
  610. else:
  611. self.rope = None
  612. self.norm_pre = norm_layer(embed_dim, **dd) if activate_pre_norm else nn.Identity()
  613. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  614. block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
  615. self.blocks = nn.ModuleList([
  616. block_fn(
  617. dim=embed_dim,
  618. num_heads=num_heads,
  619. qkv_bias=qkv_bias,
  620. qkv_fused=qkv_fused,
  621. mlp_ratio=mlp_ratio,
  622. swiglu_mlp=swiglu_mlp,
  623. swiglu_align_to=swiglu_align_to,
  624. scale_mlp=scale_mlp,
  625. scale_attn_inner=scale_attn_inner,
  626. attn_type=attn_type,
  627. rotate_half=rope_rotate_half,
  628. num_prefix_tokens=self.num_prefix_tokens,
  629. proj_drop=proj_drop_rate,
  630. attn_drop=attn_drop_rate,
  631. drop_path=dpr[i],
  632. norm_layer=norm_layer,
  633. init_values=init_values,
  634. **dd,
  635. )
  636. for i in range(depth)])
  637. self.feature_info = [
  638. dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
  639. self.norm = norm_layer(embed_dim, **dd) if activate_post_norm else nn.Identity()
  640. if global_pool == 'map':
  641. self.attn_pool = AttentionPoolLatent(
  642. self.embed_dim,
  643. num_heads=attn_pool_num_heads or num_heads,
  644. mlp_ratio=attn_pool_mlp_ratio or mlp_ratio,
  645. norm_layer=norm_layer,
  646. act_layer=nn.GELU,
  647. **dd,
  648. )
  649. else:
  650. self.attn_pool = None
  651. self.fc_norm = norm_layer(embed_dim, **dd) if activate_fc_norm else nn.Identity()
  652. self.head_drop = nn.Dropout(drop_rate)
  653. self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  654. self.init_weights(head_init_scale=head_init_scale)
  def init_weights(self, head_init_scale=None):
      """(Re)initialize all model weights.

      Linear layers are initialized through ``self._init_weights`` (via ``self.apply``).
      The absolute position embedding, class token and register tokens (whichever
      exist) are drawn from a truncated normal with std 0.02, in that order. The
      ``attn.proj`` / ``mlp.fc2`` weights of each block are then depth-rescaled by
      ``self.fix_init_weight()``.

      Args:
          head_init_scale: When truthy and the head is an ``nn.Linear``, the head weight
              is re-drawn (std 0.02) and both head weight and bias are multiplied by
              this factor.
      """
      self.apply(self._init_weights)
      for token_param in (self.pos_embed, self.cls_token, self.reg_token):
          if token_param is not None:
              trunc_normal_(token_param, std=.02)
      self.fix_init_weight()

      classifier = self.head
      if not head_init_scale or not isinstance(classifier, nn.Linear):
          return
      trunc_normal_(classifier.weight, std=.02)
      classifier.weight.data.mul_(head_init_scale)
      classifier.bias.data.mul_(head_init_scale)
  def fix_init_weight(self) -> None:
      """Rescale each block's ``attn.proj`` and ``mlp.fc2`` weights in place.

      Block ``i`` (0-based) has both weights divided by ``sqrt(2 * (i + 1))``, so the
      divisor grows with depth.
      """
      for depth, block in enumerate(self.blocks, start=1):
          divisor = math.sqrt(2.0 * depth)
          for weight in (block.attn.proj.weight, block.mlp.fc2.weight):
              weight.data.div_(divisor)
  def _init_weights(self, m: nn.Module) -> None:
      """Per-module initializer applied by ``init_weights`` through ``self.apply``.

      ``nn.Linear`` weights get a truncated normal (std 0.02) and their biases, if any,
      are zeroed. Every other module type is left untouched.

      Args:
          m: Module to initialize.
      """
      if not isinstance(m, nn.Linear):
          return
      trunc_normal_(m.weight, std=.02)
      if m.bias is not None:
          nn.init.zeros_(m.bias)
  @torch.jit.ignore
  def no_weight_decay(self) -> Set[str]:
      """Names of parameters that should be excluded from weight decay.

      Always includes ``pos_embed`` and ``cls_token``. If a (truthy) ``rope`` module
      exposes ``no_weight_decay()``, its names are added with a ``rope.`` prefix.
      """
      skip = {'pos_embed', 'cls_token'}
      rope = getattr(self, "rope", None)
      if not rope or not hasattr(rope, "no_weight_decay"):
          return skip
      return skip | {f"rope.{name}" for name in rope.no_weight_decay()}
  @torch.jit.ignore
  def set_grad_checkpointing(self, enable: bool = True) -> None:
      """Turn gradient checkpointing of the transformer blocks on or off.

      Args:
          enable: Value stored in ``self.grad_checkpointing``, which the block loops in
              ``forward_features`` / ``forward_intermediates`` consult.
      """
      setattr(self, 'grad_checkpointing', enable)
  @torch.jit.ignore
  def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
      """Create parameter-name groupings for layer-wise optimization (e.g. LR decay).

      Args:
          coarse: Unused; the same grouping is returned either way.

      Returns:
          Dict of regex specs:
          - ``stem`` covers the embedding-stage parameters: class / register tokens,
            absolute position embedding, patch embedding and the pre-transformer norm.
          - ``blocks`` maps ``blocks.N.*`` to group N and the final ``norm`` to a
            trailing group.
      """
      matcher = dict(
          # Anchor every alternative. `norm_pre` runs before the blocks, so it belongs to
          # the stem. `reg_token` is a prefix token just like `cls_token`.
          stem=r'^(?:cls_token|reg_token|pos_embed|patch_embed|norm_pre)',
          # `^norm\.` (not `^norm`) so `norm_pre.*` is not lumped in with the final norm.
          blocks=[(r'^blocks\.(\d+)', None), (r'^norm\.', (99999,))],
      )
      return matcher
  @torch.jit.ignore
  def get_classifier(self) -> nn.Module:
      """Return the final classification layer (``nn.Identity`` when there are no classes)."""
      classifier = self.head
      return classifier
  def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
      """Reset the classifier head and, optionally, the global pooling type.

      Args:
          num_classes: Number of output classes. 0 replaces the head with ``nn.Identity``.
          global_pool: New global pooling type, or ``None`` to keep the current one.
              ``pool()`` always uses ``attn_pool`` when it exists, so switching away
              from ``'map'`` removes the attention-pool module. Otherwise it would
              silently override the requested type. Switching *to* ``'map'`` requires
              an existing attention-pool module.

      Raises:
          ValueError: If ``global_pool`` is not a supported type, or is ``'map'`` while
              the model has no attention-pool module.
      """
      if global_pool is not None:
          # validate before mutating anything so a bad call leaves the model untouched
          if global_pool not in ('', 'avg', 'avgmax', 'max', 'token', 'map'):
              raise ValueError(f'Unsupported global_pool type: {global_pool!r}')
          if global_pool == 'map' and self.attn_pool is None:
              raise ValueError('Cannot add attention pooling (global_pool="map") in reset_classifier().')
          if global_pool != 'map' and self.attn_pool is not None:
              self.attn_pool = None  # otherwise pool() would keep using attention pooling
          self.global_pool = global_pool
      self.num_classes = num_classes
      self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  def set_input_size(
          self,
          img_size: Optional[Tuple[int, int]] = None,
          patch_size: Optional[Tuple[int, int]] = None,
  ) -> None:
      """Update the input image resolution and/or patch size.

      The patch embedding is reconfigured first. If a learned absolute position
      embedding exists and its token count no longer matches the new grid (plus prefix
      tokens, unless ``no_embed_class``), it is resampled from the old grid to the new
      one. Any RoPE module is told the new grid shape.

      Args:
          img_size: New input resolution. If None, the current resolution is used.
          patch_size: New patch size. If None, the existing patch size is used.
      """
      old_grid = self.patch_embed.grid_size
      self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)

      if self.pos_embed is not None:
          prefix_len = self.num_prefix_tokens if not self.no_embed_class else 0
          wanted_len = self.patch_embed.num_patches + prefix_len
          if self.pos_embed.shape[1] != wanted_len:
              self.pos_embed = nn.Parameter(resample_abs_pos_embed(
                  self.pos_embed,
                  new_size=self.patch_embed.grid_size,
                  old_size=old_grid,
                  num_prefix_tokens=prefix_len,
                  verbose=True,
              ))

      if self.rope is not None:
          self.rope.update_feat_shape(self.patch_embed.grid_size)
  def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
      """Prepend prefix tokens, add position embeddings and apply patch dropout.

      Args:
          x: Output of ``self.patch_embed``. When ``dynamic_img_size`` is set, the patch
              embed was built with ``output_fmt='NHWC'``; ``x`` is unpacked as
              (B, H, W, C) and flattened to (B, H*W, C) here. Otherwise ``x`` is
              presumably already a (B, N, C) token sequence (TODO confirm against
              PatchEmbed).

      Returns:
          Tuple ``(tokens, rot_pos_embed)``. ``rot_pos_embed`` is None when RoPE is
          disabled. When patch dropout returns keep indices, the RoPE embedding is
          gathered with them and gains a batch dim. For 'mixed' RoPE it is then
          transposed so depth is at index 0; otherwise a singleton head dim is inserted.
      """
      if self.dynamic_img_size:
          # grid size varies per call: resample the learned pos embed to this (H, W)
          B, H, W, C = x.shape
          if self.pos_embed is not None:
              prev_grid_size = self.patch_embed.grid_size
              pos_embed = resample_abs_pos_embed(
                  self.pos_embed,
                  new_size=(H, W),
                  old_size=prev_grid_size,
                  num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
              )
          else:
              pos_embed = None
          x = x.view(B, -1, C)
          rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
      else:
          # fixed grid: stored embeddings are used as-is
          pos_embed = self.pos_embed
          rot_pos_embed = self.rope.get_embed() if self.rope is not None else None

      # prefix tokens (class first, then registers), expanded over the batch
      to_cat = []
      if self.cls_token is not None:
          to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
      if self.reg_token is not None:
          to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

      if self.no_embed_class:
          # position embedding does not overlap with class / reg token
          if pos_embed is not None:
              x = x + pos_embed
          if to_cat:
              x = torch.cat(to_cat + [x], dim=1)
      else:
          # pos_embed has entry for class / reg token, concat then add
          if to_cat:
              x = torch.cat(to_cat + [x], dim=1)
          if pos_embed is not None:
              x = x + pos_embed

      x = self.pos_drop(x)

      # apply patch dropout to patches and rotary position embedding
      if self.patch_drop is not None:
          x, keep_indices = self.patch_drop(x)
          if rot_pos_embed is not None and keep_indices is not None:
              rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
              # After applying keep indices to rope embeds, batch dim is added
              if getattr(self, 'rope_mixed', False):
                  # B, D, nH, N, dim -> D, B, nH, N, dim. For consistent iteration over depth at index 0.
                  rot_pos_embed = rot_pos_embed.transpose(0, 1)
              else:
                  # B, N, dim -> B, 1, N, dim. Need head dim singleton for correct dim alignment in axial mode.
                  rot_pos_embed = rot_pos_embed.unsqueeze(1)

      return x, rot_pos_embed
  def forward_intermediates(
          self,
          x: torch.Tensor,
          indices: Optional[Union[int, List[int]]] = None,
          return_prefix_tokens: bool = False,
          norm: bool = False,
          stop_early: bool = False,
          output_fmt: str = 'NCHW',
          intermediates_only: bool = False,
  ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
      """ Forward features that returns intermediates.

      Args:
          x: Input image tensor, unpacked as (B, C, H, W).
          indices: Take last n blocks if an int, if is a sequence, select by matching indices
          return_prefix_tokens: Return both prefix and spatial intermediate tokens. For a
              model without class / register tokens, the prefix part is a zero-length
              token slice.
          norm: Apply norm layer to all intermediates
          stop_early: Stop iterating over blocks when last desired intermediate hit
          output_fmt: Shape of intermediate feature outputs ('NCHW' or 'NLC')
          intermediates_only: Only return intermediate features

      Returns:
          The list of intermediates if ``intermediates_only``. Otherwise a tuple of
          (final normed features, intermediates).
      """
      assert output_fmt in ('NCHW', 'NLC'), 'Output format for EVA-ViT features must be one of NCHW or NLC.'
      reshape = output_fmt == 'NCHW'
      intermediates = []
      take_indices, max_index = feature_take_indices(len(self.blocks), indices)

      # forward pass
      B, _, height, width = x.shape
      x = self.patch_embed(x)
      x, rot_pos_embed = self._pos_embed(x)
      x = self.norm_pre(x)
      if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
          blocks = self.blocks
      else:
          blocks = self.blocks[:max_index + 1]

      # Handle depth-dependent embeddings for mixed mode
      if getattr(self, 'rope_mixed', False) and rot_pos_embed is not None:
          for i, blk in enumerate(blocks):
              if self.grad_checkpointing and not torch.jit.is_scripting():
                  x = checkpoint(blk, x, rope=rot_pos_embed[i])
              else:
                  x = blk(x, rope=rot_pos_embed[i])
              if i in take_indices:
                  intermediates.append(self.norm(x) if norm else x)
      else:
          for i, blk in enumerate(blocks):
              if self.grad_checkpointing and not torch.jit.is_scripting():
                  x = checkpoint(blk, x, rope=rot_pos_embed)
              else:
                  x = blk(x, rope=rot_pos_embed)
              if i in take_indices:
                  intermediates.append(self.norm(x) if norm else x)

      # process intermediates
      if self.num_prefix_tokens:
          # split prefix (e.g. class, register) and spatial feature tokens
          prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
          intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
      else:
          # No prefix tokens. Bind zero-length slices so that return_prefix_tokens=True
          # yields (spatial, prefix) pairs instead of failing on an unbound name.
          prefix_tokens = [y[:, 0:0] for y in intermediates]
      if reshape:
          # reshape to BCHW output format
          H, W = self.patch_embed.dynamic_feat_size((height, width))
          intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
      if not torch.jit.is_scripting() and return_prefix_tokens:
          # return_prefix not support in torchscript due to poor type handling
          intermediates = list(zip(intermediates, prefix_tokens))

      if intermediates_only:
          return intermediates

      x = self.norm(x)

      return x, intermediates
  def prune_intermediate_layers(
          self,
          indices: Union[int, List[int]] = 1,
          prune_norm: bool = False,
          prune_head: bool = True,
  ):
      """Remove modules that are not needed to produce the requested intermediates.

      Args:
          indices: Intermediate selection, interpreted by ``feature_take_indices``.
          prune_norm: Replace the final ``norm`` with ``nn.Identity``.
          prune_head: Drop attention pooling, replace ``fc_norm`` with ``nn.Identity`` and
              reset the classifier to no classes / no pooling.

      Returns:
          The resolved take indices.
      """
      wanted, deepest = feature_take_indices(len(self.blocks), indices)
      # keep blocks up to and including the deepest one that is still required
      self.blocks = self.blocks[:deepest + 1]

      if prune_norm:
          self.norm = nn.Identity()

      if prune_head:
          self.attn_pool = None
          self.fc_norm = nn.Identity()
          self.reset_classifier(0, '')

      return wanted
  def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
      """Reduce a token sequence to one feature vector per sample.

      An attention-pool module, when present, always takes precedence and
      ``pool_type`` is ignored. Otherwise ``global_pool_nlc`` is used with
      ``pool_type`` (falling back to ``self.global_pool`` when it is None).
      """
      if self.attn_pool is not None:
          return self.attn_pool(x)
      effective_type = pool_type if pool_type is not None else self.global_pool
      return global_pool_nlc(x, pool_type=effective_type, num_prefix_tokens=self.num_prefix_tokens)
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
      """Embed the input and run it through all transformer blocks.

      Args:
          x: Input tensor.
      Returns:
          All tokens (prefix and patch) after the final ``self.norm``.
      """
      tokens, rope_embed = self._pos_embed(self.patch_embed(x))
      tokens = self.norm_pre(tokens)

      if getattr(self, 'rope_mixed', False) and rope_embed is not None:
          # 'mixed' RoPE: one embedding per block, indexed along dim 0
          for depth_idx, block in enumerate(self.blocks):
              block_rope = rope_embed[depth_idx]
              if self.grad_checkpointing and not torch.jit.is_scripting():
                  tokens = checkpoint(block, tokens, rope=block_rope)
              else:
                  tokens = block(tokens, rope=block_rope)
      else:
          # a single (possibly None) RoPE embedding shared by every block
          for block in self.blocks:
              if self.grad_checkpointing and not torch.jit.is_scripting():
                  tokens = checkpoint(block, tokens, rope=rope_embed)
              else:
                  tokens = block(tokens, rope=rope_embed)

      return self.norm(tokens)
  def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
      """Pool token features and apply the classifier.

      Args:
          x: Token features, e.g. from ``forward_features``.
          pre_logits: If True, return the features after pooling, ``fc_norm`` and head
              dropout, without applying ``self.head``.
      Returns:
          Logits, or pre-logit features when ``pre_logits`` is True.
      """
      pooled = self.head_drop(self.fc_norm(self.pool(x)))
      if pre_logits:
          return pooled
      return self.head(pooled)
  def forward(self, x: torch.Tensor) -> torch.Tensor:
      """Full forward pass: ``forward_features`` followed by ``forward_head``.

      Args:
          x: Input tensor.
      Returns:
          Output of the classifier head.
      """
      return self.forward_head(self.forward_features(x))
  929. def _convert_pe(
  930. state_dict: Dict[str, torch.Tensor],
  931. model: nn.Module,
  932. prefix: str = 'visual.',
  933. ) -> Dict[str, torch.Tensor]:
  934. """Convert Perception Encoder weights.
  935. Args:
  936. state_dict: State dictionary to convert.
  937. model: Target model instance.
  938. prefix: Prefix to strip from keys.
  939. Returns:
  940. Converted state dictionary.
  941. """
  942. state_dict = state_dict.get('model', state_dict)
  943. state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
  944. out_dict = {}
  945. swaps = [
  946. ('conv1', 'patch_embed.proj'),
  947. ('positional_embedding', 'pos_embed'),
  948. ('transformer.resblocks.', 'blocks.'),
  949. ('ln_pre', 'norm_pre'),
  950. ('ln_post', 'norm'),
  951. ('ln_', 'norm'),
  952. ('ls_1.gamma', 'gamma_1'),
  953. ('ls_2.gamma', 'gamma_2'),
  954. ('in_proj_', 'qkv.'),
  955. ('out_proj', 'proj'),
  956. ('mlp.c_fc', 'mlp.fc1'),
  957. ('mlp.c_proj', 'mlp.fc2'),
  958. ]
  959. len_prefix = len(prefix)
  960. for k, v in state_dict.items():
  961. if prefix:
  962. if not k.startswith(prefix):
  963. continue
  964. k = k[len_prefix:]
  965. for sp in swaps:
  966. k = k.replace(sp[0], sp[1])
  967. if k.startswith('attn_pool'):
  968. k = k.replace('attn_pool.attn', 'attn_pool')
  969. k = k.replace('attn_pool.layernorm', 'attn_pool.norm')
  970. k = k.replace('attn_pool.probe', 'attn_pool.latent')
  971. if k.startswith('attn_pool.qkv'):
  972. dim = v.shape[0] // 3
  973. if k.endswith('weight'):
  974. out_dict['attn_pool.q.weight'] = v[:dim]
  975. out_dict['attn_pool.kv.weight'] = v[dim:]
  976. elif k.endswith('bias'):
  977. out_dict['attn_pool.q.bias'] = v[:dim]
  978. out_dict['attn_pool.kv.bias'] = v[dim:]
  979. continue
  980. elif k == 'proj':
  981. k = 'head.weight'
  982. v = v.transpose(0, 1)
  983. out_dict['head.bias'] = torch.zeros(v.shape[0])
  984. elif k == 'class_embedding':
  985. k = 'cls_token'
  986. v = v.unsqueeze(0).unsqueeze(1)
  987. elif k == 'pos_embed':
  988. v = v.unsqueeze(0)
  989. out_dict[k] = v
  990. return out_dict
  def checkpoint_filter_fn(
          state_dict: Dict[str, torch.Tensor],
          model: nn.Module,
          interpolation: str = 'bicubic',
          antialias: bool = True,
  ) -> Dict[str, torch.Tensor]:
      """Convert a pretrained checkpoint into a state dict loadable by ``model``.

      The conversion proceeds as follows:
      - Unwrap common checkpoint containers.
      - Delegate Meta PE (Perception Encoder) checkpoints to ``_convert_pe``.
      - Strip OpenCLIP ``visual.`` / ``visual.trunk.`` prefixes.
      - Rename DINOv3, EVA / EVA-02 and MIM-pretrain keys to this module's naming.
      - Resample patch-embed kernels and absolute position embeddings whose size
        differs from ``model``.

      Args:
          state_dict: Checkpoint state dictionary.
          model: Target model instance.
          interpolation: Interpolation method for resizing.
          antialias: Whether to use antialiasing when resizing.

      Returns:
          Filtered state dictionary.
      """
      out_dict = {}
      # Standard EVA checkpoint processing: unwrap EMA / training-state / wrapper containers
      for container_key in ('model_ema', 'model', 'module', 'state_dict'):
          state_dict = state_dict.get(container_key, state_dict)

      # Loading Meta PE (Perception Encoder) weights
      if 'visual.conv1.weight' in state_dict:
          return _convert_pe(state_dict, model)
      if 'conv1.weight' in state_dict:
          return _convert_pe(state_dict, model, prefix='')

      # prefix for loading OpenCLIP compatible weights
      if 'visual.trunk.pos_embed' in state_dict:
          prefix = 'visual.trunk.'
      elif 'visual.pos_embed' in state_dict:
          prefix = 'visual.'
      else:
          prefix = ''

      dinov3_weights = 'storage_tokens' in state_dict
      mim_weights = not dinov3_weights and prefix + 'mask_token' in state_dict
      no_qkv = prefix + 'blocks.0.attn.q_proj.weight' in state_dict
      len_prefix = len(prefix)

      for k, v in state_dict.items():
          if prefix:
              if not k.startswith(prefix):
                  continue
              k = k[len_prefix:]

          if 'rope' in k and not k == 'rope.freqs':
              # fixed embedding no need to load buffer from checkpoint
              continue

          if dinov3_weights:
              if k.endswith(('.periods', '.bias_mask', 'mask_token')):
                  # discard unused/non-persistent/pretrain only params
                  continue
              if k.startswith('local_cls_norm'):
                  # discard, only used for 7b dinov3 pretrain w/ local crops
                  continue
              if k.endswith('qkv.bias'):
                  q_bias_k = k.replace('qkv.bias', 'q_bias')
                  try:
                      # the distilled b,l,h models ended up with all zero biases, so timm
                      # has both qkv_bias=True and qkv_bias=False impl, test which
                      model.get_parameter(q_bias_k)
                  except AttributeError:
                      # expected for qkv_bias=False targets: no q/v bias params, drop quietly
                      continue
                  # split bias into components and skip the k as its supposed to be fixed at 0
                  q_bias, _k_bias, v_bias = v.chunk(3, dim=-1)
                  out_dict[q_bias_k] = q_bias
                  out_dict[k.replace('qkv.bias', 'v_bias')] = v_bias
                  continue
              k = k.replace('ls1.gamma', 'gamma_1')  # match EVA ls naming
              k = k.replace('ls2.gamma', 'gamma_2')  # match EVA ls naming
              k = k.replace('storage_tokens', 'reg_token')  # rename storage to existing register naming
          elif mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'):
              if k in ('norm.weight', 'norm.bias'):
                  # try moving norm -> fc norm on fine-tune, probably a better starting point than new init
                  k = k.replace('norm', 'fc_norm')
              else:
                  # skip pretrain mask token & head weights
                  continue

          if 'patch_embed.proj.weight' in k:
              _, _, H, W = model.patch_embed.proj.weight.shape
              if v.shape[-1] != W or v.shape[-2] != H:
                  v = resample_patch_embed(
                      v,
                      (H, W),
                      interpolation=interpolation,
                      antialias=antialias,
                      verbose=True,
                  )
          elif k == 'pos_embed':
              # Target may have no absolute pos embed (use_abs_pos_emb=False -> None).
              # Pass the key through and let load_state_dict report it, don't crash here.
              model_pos_embed = getattr(model, 'pos_embed', None)
              if model_pos_embed is not None and v.shape[1] != model_pos_embed.shape[1]:
                  # To resize pos embedding when using model at different size from pretrained weights
                  num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
                  v = resample_abs_pos_embed(
                      v,
                      new_size=model.patch_embed.grid_size,
                      num_prefix_tokens=num_prefix_tokens,
                      interpolation=interpolation,
                      antialias=antialias,
                      verbose=True,
                  )

          k = k.replace('mlp.ffn_ln', 'mlp.norm')
          k = k.replace('attn.inner_attn_ln', 'attn.norm')
          # 'mlp.w12' must be handled before 'mlp.w1' (substring overlap)
          k = k.replace('mlp.w12', 'mlp.fc1')
          k = k.replace('mlp.w1', 'mlp.fc1_g')
          k = k.replace('mlp.w2', 'mlp.fc1_x')
          k = k.replace('mlp.w3', 'mlp.fc2')
          if no_qkv:
              k = k.replace('q_bias', 'q_proj.bias')
              k = k.replace('v_bias', 'v_proj.bias')

          out_dict[k] = v

      return out_dict
  def _create_eva(variant: str, pretrained: bool = False, **kwargs) -> Eva:
      """Build an EVA model for a registered variant.

      Args:
          variant: Model variant name.
          pretrained: Load pretrained weights.
          **kwargs: Additional model arguments. Two keys are consumed here:
              - ``use_naflex``: when truthy (or when None and the ``TIMM_USE_NAFLEX``
                environment variable equals ``'1'``), creation is delegated to the
                NaFlexVit builder.
              - ``out_indices``: defaults to 3 and configures feature extraction for
                the ``Eva`` path.

      Returns:
          Instantiated model.
      """
      use_naflex = kwargs.pop('use_naflex', None)
      env_default = os.environ.get('TIMM_USE_NAFLEX', '0') == '1'
      if use_naflex is None:
          use_naflex = env_default
      if use_naflex:
          # Import here to avoid circular import
          from .naflexvit import _create_naflexvit_from_eva
          return _create_naflexvit_from_eva(variant, pretrained, **kwargs)

      feature_cfg = dict(out_indices=kwargs.pop('out_indices', 3), feature_cls='getter')
      return build_model_with_cfg(
          Eva,
          variant,
          pretrained,
          pretrained_filter_fn=checkpoint_filter_fn,
          feature_cfg=feature_cfg,
          **kwargs,
      )
  def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Default pretrained-config dict for EVA models.

      Uses a 1000-class head at 224x224, OpenAI CLIP normalization and the MIT license.
      Any ``kwargs`` override or extend these entries.

      Args:
          url: Model weights URL.
          **kwargs: Additional / overriding configuration parameters.
      Returns:
          Model configuration dictionary.
      """
      cfg = dict(
          url=url,
          num_classes=1000,
          input_size=(3, 224, 224),
          pool_size=None,
          crop_pct=.9,
          interpolation='bicubic',
          fixed_input_size=True,
          mean=OPENAI_CLIP_MEAN,
          std=OPENAI_CLIP_STD,
          first_conv='patch_embed.proj',
          classifier='head',
          license='mit',
      )
      cfg.update(kwargs)
      return cfg
  1142. def _pe_cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  1143. """Generate default configuration for Perception Encoder models.
  1144. Args:
  1145. url: Model weights URL.
  1146. **kwargs: Additional configuration parameters.
  1147. Returns:
  1148. Model configuration dictionary.
  1149. """
  1150. return {
  1151. 'url': url,
  1152. 'num_classes': 0, 'input_size': (3, 224, 224), 'pool_size': None,
  1153. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
  1154. 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
  1155. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  1156. 'license': 'apache-2.0', **kwargs
  1157. }
  def _dinov3_cfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Default pretrained-config dict for DINOv3 models.

      Uses no classifier head (``num_classes=0``), 256x256 input with full crop,
      ImageNet default normalization and the DINOv3 license. Any ``kwargs`` override
      or extend these entries.

      Args:
          url: Model weights URL.
          **kwargs: Additional / overriding configuration parameters.
      Returns:
          Model configuration dictionary.
      """
      cfg = dict(
          url=url,
          num_classes=0,
          input_size=(3, 256, 256),
          pool_size=None,
          crop_pct=1.0,
          interpolation='bicubic',
          fixed_input_size=True,
          mean=IMAGENET_DEFAULT_MEAN,
          std=IMAGENET_DEFAULT_STD,
          first_conv='patch_embed.proj',
          classifier='head',
          license='dinov3-license',
      )
      cfg.update(kwargs)
      return cfg
  1174. default_cfgs = generate_default_cfgs({
  1175. # EVA 01 CLIP fine-tuned on imagenet-1k
  1176. 'eva_giant_patch14_224.clip_ft_in1k': _cfg(
  1177. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
  1178. hf_hub_id='timm/',
  1179. ),
  1180. 'eva_giant_patch14_336.clip_ft_in1k': _cfg(
  1181. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
  1182. hf_hub_id='timm/',
  1183. input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
  1184. # MIM EVA 01 pretrain, ft on in22k -> in1k
  1185. 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
  1186. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
  1187. hf_hub_id='timm/',
  1188. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1189. input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
  1190. 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
  1191. # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
  1192. hf_hub_id='timm/',
  1193. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1194. input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
  1195. # in22k or m38m MIM pretrain w/ intermediate in22k fine-tune and final in1k fine-tune
  1196. 'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
  1197. # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
  1198. hf_hub_id='timm/',
  1199. input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
  1200. ),
  1201. 'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
  1202. # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
  1203. hf_hub_id='timm/',
  1204. input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
  1205. ),
  1206. 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg(
  1207. hf_hub_id='timm/',
  1208. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
  1209. input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
  1210. ),
  1211. # in22k or m3m MIM pretrain w/ in1k fine-tune
  1212. 'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg(
  1213. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
  1214. hf_hub_id='timm/',
  1215. input_size=(3, 336, 336), crop_pct=1.0,
  1216. ),
  1217. 'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg(
  1218. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
  1219. hf_hub_id='timm/',
  1220. input_size=(3, 336, 336), crop_pct=1.0,
  1221. ),
  1222. 'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg(
  1223. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
  1224. hf_hub_id='timm/',
  1225. input_size=(3, 448, 448), crop_pct=1.0,
  1226. ),
  1227. 'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg(
  1228. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
  1229. hf_hub_id='timm/',
  1230. input_size=(3, 448, 448), crop_pct=1.0,
  1231. ),
  1232. 'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg(
  1233. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
  1234. hf_hub_id='timm/',
  1235. input_size=(3, 448, 448), crop_pct=1.0,
  1236. ),
  1237. # in22k or m3m MIM pretrain w/ in22k fine-tune
  1238. 'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg(
  1239. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
  1240. hf_hub_id='timm/',
  1241. input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
  1242. ),
  1243. 'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg(
  1244. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
  1245. hf_hub_id='timm/',
  1246. input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
  1247. ),
  1248. 'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg(
  1249. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
  1250. hf_hub_id='timm/',
  1251. input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
  1252. ),
  1253. # in22k or m38m MIM pretrain
  1254. 'eva02_tiny_patch14_224.mim_in22k': _cfg(
  1255. # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
  1256. hf_hub_id='timm/',
  1257. num_classes=0,
  1258. ),
  1259. 'eva02_small_patch14_224.mim_in22k': _cfg(
  1260. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
  1261. hf_hub_id='timm/',
  1262. num_classes=0,
  1263. ),
  1264. 'eva02_base_patch14_224.mim_in22k': _cfg(
  1265. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
  1266. hf_hub_id='timm/',
  1267. num_classes=0,
  1268. ),
  1269. 'eva02_large_patch14_224.mim_in22k': _cfg(
  1270. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
  1271. hf_hub_id='timm/',
  1272. num_classes=0,
  1273. ),
  1274. 'eva02_large_patch14_224.mim_m38m': _cfg(
  1275. #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
  1276. hf_hub_id='timm/',
  1277. num_classes=0,
  1278. ),
  1279. # EVA01 and EVA02 CLIP image towers
  1280. 'eva_giant_patch14_clip_224.laion400m': _cfg(
  1281. # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
  1282. # hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k', # float16 weights
  1283. # hf_hub_filename='open_clip_pytorch_model.bin',
  1284. hf_hub_id='timm/',
  1285. num_classes=1024,
  1286. ),
  1287. 'eva_giant_patch14_clip_224.merged2b': _cfg(
  1288. # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
  1289. # hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k', # float16 weights
  1290. # hf_hub_filename='open_clip_pytorch_model.bin',
  1291. hf_hub_id='timm/',
  1292. num_classes=1024,
  1293. ),
  1294. 'eva02_base_patch16_clip_224.merged2b': _cfg(
  1295. # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
  1296. # hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k', # float16 weights
  1297. # hf_hub_filename='open_clip_pytorch_model.bin',
  1298. hf_hub_id='timm/',
  1299. num_classes=512,
  1300. ),
  1301. 'eva02_large_patch14_clip_224.merged2b': _cfg(
  1302. # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
  1303. # hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k', # float16 weights
  1304. # hf_hub_filename='open_clip_pytorch_model.bin',
  1305. hf_hub_id='timm/',
  1306. num_classes=768,
  1307. ),
  1308. 'eva02_large_patch14_clip_336.merged2b': _cfg(
  1309. # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
  1310. # hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k', # float16 weights
  1311. # hf_hub_filename='open_clip_pytorch_model.bin',
  1312. hf_hub_id='timm/',
  1313. input_size=(3, 336, 336), crop_pct=1.0,
  1314. num_classes=768,
  1315. ),
  1316. 'eva02_enormous_patch14_clip_224.laion2b': _cfg(
  1317. # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
  1318. # hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k', # float16 weights
  1319. # hf_hub_filename='open_clip_pytorch_model.bin',
  1320. hf_hub_id='timm/',
  1321. num_classes=1024,
  1322. ),
  1323. 'eva02_enormous_patch14_clip_224.laion2b_plus': _cfg(
  1324. # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
  1325. # hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k', # bfloat16 weights
  1326. # hf_hub_filename='open_clip_pytorch_model.bin',
  1327. hf_hub_id='timm/',
  1328. num_classes=1024,
  1329. ),
  1330. 'eva02_enormous_patch14_clip_224.pretrain': _cfg(
  1331. # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_E_psz14.pt',
  1332. num_classes=0,
  1333. ),
  1334. 'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
  1335. hf_hub_id='timm/',
  1336. input_size=(3, 256, 256), crop_pct=0.95,
  1337. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
  1338. ),
  1339. 'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
  1340. hf_hub_id='timm/',
  1341. input_size=(3, 256, 256), crop_pct=0.95,
  1342. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
  1343. ),
  1344. 'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg(
  1345. hf_hub_id='timm/',
  1346. input_size=(3, 256, 256), crop_pct=0.95,
  1347. ),
  1348. 'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
  1349. hf_hub_id='timm/',
  1350. input_size=(3, 256, 256), crop_pct=0.95,
  1351. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
  1352. ),
  1353. # Perception Encoder weights
  1354. 'vit_pe_core_tiny_patch16_384.fb': _pe_cfg(
  1355. hf_hub_id='timm/',
  1356. #hf_hub_id='facebook/PE-Core-T16-384',
  1357. #hf_hub_filename='PE-Core-T16-384.pt',
  1358. input_size=(3, 384, 384),
  1359. num_classes=512, # output proj dim
  1360. ),
  1361. 'vit_pe_core_small_patch16_384.fb': _pe_cfg(
  1362. hf_hub_id='timm/',
  1363. #hf_hub_id='facebook/PE-Core-S16-384',
  1364. #hf_hub_filename='PE-Core-S16-384.pt',
  1365. input_size=(3, 384, 384),
  1366. num_classes=512, # output proj dim
  1367. ),
  1368. 'vit_pe_core_base_patch16_224.fb': _pe_cfg(
  1369. hf_hub_id='timm/',
  1370. #hf_hub_id='facebook/PE-Core-B16-224',
  1371. #hf_hub_filename='PE-Core-B16-224.pt',
  1372. input_size=(3, 224, 224),
  1373. num_classes=1024, # output proj dim
  1374. ),
  1375. 'vit_pe_core_large_patch14_336.fb': _pe_cfg(
  1376. hf_hub_id='timm/',
  1377. #hf_hub_id='facebook/PE-Core-L14-336',
  1378. #hf_hub_filename='PE-Core-L14-336.pt',
  1379. input_size=(3, 336, 336),
  1380. num_classes=1024, # output proj dim
  1381. ),
  1382. 'vit_pe_core_gigantic_patch14_448.fb': _pe_cfg(
  1383. hf_hub_id='timm/',
  1384. #hf_hub_id='facebook/PE-Core-G14-448',
  1385. #hf_hub_filename='PE-Core-G14-448.pt',
  1386. input_size=(3, 448, 448),
  1387. num_classes=1280, # output proj dim
  1388. ),
  1389. 'vit_pe_lang_large_patch14_448.fb': _pe_cfg(
  1390. hf_hub_id='timm/',
  1391. #hf_hub_id='facebook/PE-Lang-L14-448',
  1392. #hf_hub_filename='PE-Lang-L14-448.pt',
  1393. input_size=(3, 448, 448),
  1394. num_classes=0,
  1395. ),
  1396. 'vit_pe_lang_large_patch14_448.fb_tiling': _pe_cfg(
  1397. hf_hub_id='timm/',
  1398. #hf_hub_id='facebook/PE-Lang-L14-448-Tiling',
  1399. #hf_hub_filename='PE-Lang-L14-448-Tiling.pt',
  1400. input_size=(3, 448, 448),
  1401. num_classes=0,
  1402. ),
  1403. 'vit_pe_lang_gigantic_patch14_448.fb': _pe_cfg(
  1404. hf_hub_id='timm/',
  1405. #hf_hub_id='facebook/PE-Lang-G14-448',
  1406. #hf_hub_filename='PE-Lang-G14-448.pt',
  1407. input_size=(3, 448, 448),
  1408. num_classes=0,
  1409. ),
  1410. 'vit_pe_lang_gigantic_patch14_448.fb_tiling': _pe_cfg(
  1411. hf_hub_id='timm/',
  1412. #hf_hub_id='facebook/PE-Lang-G14-448-Tiling',
  1413. #hf_hub_filename='PE-Lang-G14-448-Tiling.pt',
  1414. input_size=(3, 448, 448),
  1415. num_classes=0,
  1416. ),
  1417. 'vit_pe_spatial_tiny_patch16_512.fb': _pe_cfg(
  1418. hf_hub_id='timm/',
  1419. #hf_hub_id='facebook/PE-Spatial-T16-512',
  1420. #hf_hub_filename='PE-Spatial-T16-512.pt',
  1421. input_size=(3, 512, 512),
  1422. num_classes=0,
  1423. ),
  1424. 'vit_pe_spatial_small_patch16_512.fb': _pe_cfg(
  1425. hf_hub_id='timm/',
  1426. #hf_hub_id='facebook/PE-Spatial-S16-512',
  1427. #hf_hub_filename='PE-Spatial-S16-512.pt',
  1428. input_size=(3, 512, 512),
  1429. num_classes=0,
  1430. ),
  1431. 'vit_pe_spatial_base_patch16_512.fb': _pe_cfg(
  1432. hf_hub_id='timm/',
  1433. #hf_hub_id='facebook/PE-Spatial-B16-512',
  1434. #hf_hub_filename='PE-Spatial-B16-512.pt',
  1435. input_size=(3, 512, 512),
  1436. num_classes=0,
  1437. ),
  1438. 'vit_pe_spatial_large_patch14_448.fb': _pe_cfg(
  1439. hf_hub_id='timm/',
  1440. #hf_hub_id='facebook/PE-Spatial-L14-448',
  1441. #hf_hub_filename='PE-Spatial-L14-448.pt',
  1442. input_size=(3, 448, 448),
  1443. num_classes=0,
  1444. ),
  1445. 'vit_pe_spatial_gigantic_patch14_448.fb': _pe_cfg(
  1446. hf_hub_id='timm/',
  1447. #hf_hub_id='facebook/PE-Spatial-G14-448',
  1448. #hf_hub_filename='PE-Spatial-G14-448.pt',
  1449. input_size=(3, 448, 448),
  1450. num_classes=0,
  1451. ),
  1452. # RoPE-ViT models from Naver
  1453. 'vit_small_patch16_rope_224.naver_in1k': _cfg(
  1454. hf_hub_id='timm/',
  1455. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1456. license='apache-2.0',
  1457. ),
  1458. 'vit_base_patch16_rope_224.naver_in1k': _cfg(
  1459. hf_hub_id='timm/',
  1460. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1461. license='apache-2.0',
  1462. ),
  1463. 'vit_large_patch16_rope_224.naver_in1k': _cfg(
  1464. hf_hub_id='timm/',
  1465. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1466. license='apache-2.0',
  1467. ),
  1468. 'vit_small_patch16_rope_mixed_224.naver_in1k': _cfg(
  1469. hf_hub_id='timm/',
  1470. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1471. license='apache-2.0',
  1472. ),
  1473. 'vit_base_patch16_rope_mixed_224.naver_in1k': _cfg(
  1474. hf_hub_id='timm/',
  1475. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1476. license='apache-2.0',
  1477. ),
  1478. 'vit_large_patch16_rope_mixed_224.naver_in1k': _cfg(
  1479. hf_hub_id='timm/',
  1480. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1481. license='apache-2.0',
  1482. ),
  1483. 'vit_small_patch16_rope_ape_224.naver_in1k': _cfg(
  1484. hf_hub_id='timm/',
  1485. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1486. license='apache-2.0',
  1487. ),
  1488. 'vit_base_patch16_rope_ape_224.naver_in1k': _cfg(
  1489. hf_hub_id='timm/',
  1490. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1491. license='apache-2.0',
  1492. ),
  1493. 'vit_large_patch16_rope_ape_224.naver_in1k': _cfg(
  1494. hf_hub_id='timm/',
  1495. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1496. license='apache-2.0',
  1497. ),
  1498. 'vit_small_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
  1499. hf_hub_id='timm/',
  1500. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1501. license='apache-2.0',
  1502. ),
  1503. 'vit_base_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
  1504. hf_hub_id='timm/',
  1505. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1506. license='apache-2.0',
  1507. ),
  1508. 'vit_large_patch16_rope_mixed_ape_224.naver_in1k': _cfg(
  1509. hf_hub_id='timm/',
  1510. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
  1511. license='apache-2.0',
  1512. ),
  1513. # DINOv3 weights are under a specific license with redistribution terms, please see
  1514. # https://github.com/facebookresearch/dinov3/blob/main/LICENSE.md
  1515. 'vit_small_patch16_dinov3.lvd1689m': _dinov3_cfg(
  1516. hf_hub_id='timm/',
  1517. ),
  1518. 'vit_small_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
  1519. hf_hub_id='timm/',
  1520. ),
  1521. 'vit_small_plus_patch16_dinov3.lvd1689m': _dinov3_cfg(
  1522. hf_hub_id='timm/',
  1523. ),
  1524. 'vit_small_plus_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
  1525. hf_hub_id='timm/',
  1526. ),
  1527. 'vit_base_patch16_dinov3.lvd1689m': _dinov3_cfg(
  1528. hf_hub_id='timm/',
  1529. ),
  1530. 'vit_base_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
  1531. hf_hub_id='timm/',
  1532. ),
  1533. 'vit_large_patch16_dinov3.lvd1689m': _dinov3_cfg(
  1534. hf_hub_id='timm/',
  1535. ),
  1536. 'vit_large_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
  1537. hf_hub_id='timm/',
  1538. ),
  1539. 'vit_large_patch16_dinov3.sat493m': _dinov3_cfg(
  1540. hf_hub_id='timm/',
  1541. mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143),
  1542. ),
  1543. 'vit_large_patch16_dinov3_qkvb.sat493m': _dinov3_cfg(
  1544. hf_hub_id='timm/',
  1545. mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143),
  1546. ),
  1547. 'vit_huge_plus_patch16_dinov3.lvd1689m': _dinov3_cfg(
  1548. hf_hub_id='timm/',
  1549. ),
  1550. 'vit_huge_plus_patch16_dinov3_qkvb.lvd1689m': _dinov3_cfg(
  1551. hf_hub_id='timm/',
  1552. ),
  1553. 'vit_7b_patch16_dinov3.lvd1689m': _dinov3_cfg(
  1554. hf_hub_id='timm/',
  1555. ),
  1556. 'vit_7b_patch16_dinov3.sat493m': _dinov3_cfg(
  1557. hf_hub_id='timm/',
  1558. mean=(0.430, 0.411, 0.296), std=(0.213, 0.156, 0.143),
  1559. ),
  1560. })
  1561. @register_model
  1562. def eva_giant_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
  1563. """EVA-g model https://arxiv.org/abs/2211.07636"""
  1564. model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
  1565. model = _create_eva('eva_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1566. return model
  1567. @register_model
  1568. def eva_giant_patch14_336(pretrained: bool = False, **kwargs) -> Eva:
  1569. """EVA-g model https://arxiv.org/abs/2211.07636"""
  1570. model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
  1571. model = _create_eva('eva_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  1572. return model
  1573. @register_model
  1574. def eva_giant_patch14_560(pretrained: bool = False, **kwargs) -> Eva:
  1575. """EVA-g model https://arxiv.org/abs/2211.07636"""
  1576. model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
  1577. model = _create_eva('eva_giant_patch14_560', pretrained=pretrained, **dict(model_args, **kwargs))
  1578. return model
  1579. @register_model
  1580. def eva02_tiny_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
  1581. """EVA02 Tiny https://arxiv.org/abs/2303.11331"""
  1582. model_args = dict(
  1583. img_size=224,
  1584. patch_size=14,
  1585. embed_dim=192,
  1586. depth=12,
  1587. num_heads=3,
  1588. mlp_ratio=4 * 2 / 3,
  1589. swiglu_mlp=True,
  1590. use_rot_pos_emb=True,
  1591. ref_feat_shape=(16, 16), # 224/14
  1592. )
  1593. model = _create_eva('eva02_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1594. return model
  1595. @register_model
  1596. def eva02_small_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
  1597. """EVA02 Small https://arxiv.org/abs/2303.11331"""
  1598. model_args = dict(
  1599. img_size=224,
  1600. patch_size=14,
  1601. embed_dim=384,
  1602. depth=12,
  1603. num_heads=6,
  1604. mlp_ratio=4 * 2 / 3,
  1605. swiglu_mlp=True,
  1606. use_rot_pos_emb=True,
  1607. ref_feat_shape=(16, 16), # 224/14
  1608. )
  1609. model = _create_eva('eva02_small_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1610. return model
  1611. @register_model
  1612. def eva02_base_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
  1613. """EVA02 Base https://arxiv.org/abs/2303.11331"""
  1614. model_args = dict(
  1615. img_size=224,
  1616. patch_size=14,
  1617. embed_dim=768,
  1618. depth=12,
  1619. num_heads=12,
  1620. qkv_fused=False,
  1621. mlp_ratio=4 * 2 / 3,
  1622. swiglu_mlp=True,
  1623. scale_mlp=True,
  1624. use_rot_pos_emb=True,
  1625. ref_feat_shape=(16, 16), # 224/14
  1626. )
  1627. model = _create_eva('eva02_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1628. return model
  1629. @register_model
  1630. def eva02_large_patch14_224(pretrained: bool = False, **kwargs) -> Eva:
  1631. """EVA02 Large https://arxiv.org/abs/2303.11331"""
  1632. model_args = dict(
  1633. img_size=224,
  1634. patch_size=14,
  1635. embed_dim=1024,
  1636. depth=24,
  1637. num_heads=16,
  1638. mlp_ratio=4 * 2 / 3,
  1639. qkv_fused=False,
  1640. swiglu_mlp=True,
  1641. scale_mlp=True,
  1642. use_rot_pos_emb=True,
  1643. ref_feat_shape=(16, 16), # 224/14
  1644. )
  1645. model = _create_eva('eva02_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1646. return model
  1647. @register_model
  1648. def eva02_tiny_patch14_336(pretrained: bool = False, **kwargs) -> Eva:
  1649. """EVA02 Tiny https://arxiv.org/abs/2303.11331"""
  1650. model_args = dict(
  1651. img_size=336,
  1652. patch_size=14,
  1653. embed_dim=192,
  1654. depth=12,
  1655. num_heads=3,
  1656. mlp_ratio=4 * 2 / 3,
  1657. swiglu_mlp=True,
  1658. use_rot_pos_emb=True,
  1659. ref_feat_shape=(16, 16), # 224/14
  1660. )
  1661. model = _create_eva('eva02_tiny_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  1662. return model
  1663. @register_model
  1664. def eva02_small_patch14_336(pretrained: bool = False, **kwargs) -> Eva:
  1665. """EVA02 Small https://arxiv.org/abs/2303.11331"""
  1666. model_args = dict(
  1667. img_size=336,
  1668. patch_size=14,
  1669. embed_dim=384,
  1670. depth=12,
  1671. num_heads=6,
  1672. mlp_ratio=4 * 2 / 3,
  1673. swiglu_mlp=True,
  1674. use_rot_pos_emb=True,
  1675. ref_feat_shape=(16, 16), # 224/14
  1676. )
  1677. model = _create_eva('eva02_small_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  1678. return model
  1679. @register_model
  1680. def eva02_base_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  1681. """EVA02 Base https://arxiv.org/abs/2303.11331"""
  1682. model_args = dict(
  1683. img_size=448,
  1684. patch_size=14,
  1685. embed_dim=768,
  1686. depth=12,
  1687. num_heads=12,
  1688. qkv_fused=False,
  1689. mlp_ratio=4 * 2 / 3,
  1690. swiglu_mlp=True,
  1691. scale_mlp=True,
  1692. use_rot_pos_emb=True,
  1693. ref_feat_shape=(16, 16), # 224/14
  1694. )
  1695. model = _create_eva('eva02_base_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  1696. return model
  1697. @register_model
  1698. def eva02_large_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  1699. """EVA02 Large https://arxiv.org/abs/2303.11331"""
  1700. model_args = dict(
  1701. img_size=448,
  1702. patch_size=14,
  1703. embed_dim=1024,
  1704. depth=24,
  1705. num_heads=16,
  1706. mlp_ratio=4 * 2 / 3,
  1707. qkv_fused=False,
  1708. swiglu_mlp=True,
  1709. scale_mlp=True,
  1710. use_rot_pos_emb=True,
  1711. ref_feat_shape=(16, 16), # 224/14
  1712. )
  1713. model = _create_eva('eva02_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  1714. return model
  1715. @register_model
  1716. def eva_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> Eva:
  1717. """EVA-g CLIP model (only difference from non-CLIP is the pooling)"""
  1718. model_args = dict(
  1719. patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408,
  1720. global_pool=kwargs.pop('global_pool', 'token'))
  1721. model = _create_eva('eva_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1722. return model
  1723. @register_model
  1724. def eva02_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> Eva:
  1725. """An EVA-CLIP specific variant that adds additional attn scale layer-norm to eva02_base"""
  1726. model_args = dict(
  1727. img_size=224,
  1728. patch_size=16,
  1729. embed_dim=768,
  1730. depth=12,
  1731. num_heads=12,
  1732. qkv_fused=False,
  1733. mlp_ratio=4 * 2 / 3,
  1734. swiglu_mlp=True,
  1735. scale_mlp=True,
  1736. scale_attn_inner=True,
  1737. use_rot_pos_emb=True,
  1738. ref_feat_shape=(16, 16), # 224/14
  1739. global_pool=kwargs.pop('global_pool', 'token'),
  1740. )
  1741. model = _create_eva('eva02_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1742. return model
  1743. @register_model
  1744. def eva02_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> Eva:
  1745. """An EVA-CLIP specific variant that adds additional attn scale layer-norm to eva02_large"""
  1746. model_args = dict(
  1747. img_size=224,
  1748. patch_size=14,
  1749. embed_dim=1024,
  1750. depth=24,
  1751. num_heads=16,
  1752. mlp_ratio=4 * 2 / 3,
  1753. qkv_fused=False,
  1754. swiglu_mlp=True,
  1755. scale_mlp=True,
  1756. scale_attn_inner=True,
  1757. use_rot_pos_emb=True,
  1758. ref_feat_shape=(16, 16), # 224/14
  1759. global_pool=kwargs.pop('global_pool', 'token'),
  1760. )
  1761. model = _create_eva('eva02_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1762. return model
  1763. @register_model
  1764. def eva02_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> Eva:
  1765. """An EVA-CLIP specific variant that adds additional attn scale layer-norm to eva02_large"""
  1766. model_args = dict(
  1767. img_size=336,
  1768. patch_size=14,
  1769. embed_dim=1024,
  1770. depth=24,
  1771. num_heads=16,
  1772. mlp_ratio=4 * 2 / 3,
  1773. qkv_fused=False,
  1774. swiglu_mlp=True,
  1775. scale_mlp=True,
  1776. scale_attn_inner=True,
  1777. use_rot_pos_emb=True,
  1778. ref_feat_shape=(16, 16), # 224/14
  1779. global_pool=kwargs.pop('global_pool', 'token'),
  1780. )
  1781. model = _create_eva('eva02_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
  1782. return model
  1783. @register_model
  1784. def eva02_enormous_patch14_clip_224(pretrained: bool = False, **kwargs) -> Eva:
  1785. """An EVA-CLIP specific variant that uses residual post-norm in blocks"""
  1786. model_args = dict(
  1787. img_size=224,
  1788. patch_size=14,
  1789. embed_dim=1792,
  1790. depth=64,
  1791. num_heads=16,
  1792. mlp_ratio=15360 / 1792,
  1793. use_post_norm=True,
  1794. global_pool=kwargs.pop('global_pool', 'token'),
  1795. )
  1796. model = _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1797. return model
  1798. @register_model
  1799. def vit_medium_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Eva:
  1800. """timm SBB ViT with ROPE"""
  1801. model_args = dict(
  1802. img_size=256,
  1803. patch_size=16,
  1804. embed_dim=512,
  1805. depth=12,
  1806. num_heads=8,
  1807. qkv_fused=True,
  1808. qkv_bias=True,
  1809. init_values=1e-5,
  1810. class_token=False,
  1811. num_reg_tokens=1,
  1812. use_rot_pos_emb=True,
  1813. use_abs_pos_emb=False,
  1814. ref_feat_shape=(16, 16), # 224/14
  1815. )
  1816. model = _create_eva('vit_medium_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1817. return model
  1818. @register_model
  1819. def vit_mediumd_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Eva:
  1820. """timm SBB ViT with ROPE"""
  1821. model_args = dict(
  1822. img_size=256,
  1823. patch_size=16,
  1824. embed_dim=512,
  1825. depth=20,
  1826. num_heads=8,
  1827. qkv_fused=True,
  1828. qkv_bias=False,
  1829. init_values=1e-5,
  1830. class_token=False,
  1831. num_reg_tokens=1,
  1832. use_rot_pos_emb=True,
  1833. use_abs_pos_emb=False,
  1834. ref_feat_shape=(16, 16), # 224/14
  1835. )
  1836. model = _create_eva('vit_mediumd_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1837. return model
  1838. @register_model
  1839. def vit_betwixt_patch16_rope_reg4_gap_256(pretrained: bool = False, **kwargs) -> Eva:
  1840. """timm SBB ViT with ROPE"""
  1841. model_args = dict(
  1842. img_size=256,
  1843. patch_size=16,
  1844. embed_dim=640,
  1845. depth=12,
  1846. num_heads=10,
  1847. qkv_fused=True,
  1848. qkv_bias=True,
  1849. init_values=1e-5,
  1850. class_token=False,
  1851. num_reg_tokens=4,
  1852. use_rot_pos_emb=True,
  1853. use_abs_pos_emb=False,
  1854. ref_feat_shape=(16, 16), # 224/14
  1855. )
  1856. model = _create_eva('vit_betwixt_patch16_rope_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1857. return model
  1858. @register_model
  1859. def vit_base_patch16_rope_reg1_gap_256(pretrained: bool = False, **kwargs) -> Eva:
  1860. """timm SBB ViT with ROPE"""
  1861. model_args = dict(
  1862. img_size=256,
  1863. patch_size=16,
  1864. embed_dim=768,
  1865. depth=12,
  1866. num_heads=12,
  1867. qkv_fused=True,
  1868. qkv_bias=True,
  1869. init_values=1e-5,
  1870. class_token=False,
  1871. num_reg_tokens=1,
  1872. use_rot_pos_emb=True,
  1873. use_abs_pos_emb=False,
  1874. ref_feat_shape=(16, 16), # 224/14
  1875. )
  1876. model = _create_eva('vit_base_patch16_rope_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
  1877. return model
  1878. @register_model
  1879. def vit_pe_core_tiny_patch16_384(pretrained: bool = False, **kwargs) -> Eva:
  1880. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  1881. model_args = dict(
  1882. patch_size=16,
  1883. embed_dim=192,
  1884. depth=12,
  1885. num_heads=3,
  1886. mlp_ratio=4.0,
  1887. global_pool='map',
  1888. attn_type='rope',
  1889. use_pre_transformer_norm=True,
  1890. use_rot_pos_emb=True,
  1891. ref_feat_shape=(24, 24),
  1892. rope_grid_offset=1.,
  1893. rope_grid_indexing='xy',
  1894. attn_pool_num_heads=8,
  1895. attn_pool_mlp_ratio=4.,
  1896. norm_layer=partial(LayerNorm, eps=1e-5),
  1897. #dynamic_img_size=True
  1898. )
  1899. return _create_eva('vit_pe_core_tiny_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  1900. @register_model
  1901. def vit_pe_core_small_patch16_384(pretrained: bool = False, **kwargs) -> Eva:
  1902. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  1903. model_args = dict(
  1904. patch_size=16,
  1905. embed_dim=384,
  1906. depth=12,
  1907. num_heads=6,
  1908. mlp_ratio=4.0,
  1909. global_pool='map',
  1910. attn_type='rope',
  1911. use_pre_transformer_norm=True,
  1912. use_rot_pos_emb=True,
  1913. ref_feat_shape=(24, 24),
  1914. rope_grid_offset=1.,
  1915. rope_grid_indexing='xy',
  1916. attn_pool_num_heads=8,
  1917. attn_pool_mlp_ratio=4.,
  1918. norm_layer=partial(LayerNorm, eps=1e-5),
  1919. #dynamic_img_size=True
  1920. )
  1921. return _create_eva('vit_pe_core_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  1922. @register_model
  1923. def vit_pe_core_base_patch16_224(pretrained: bool = False, **kwargs) -> Eva:
  1924. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  1925. model_args = dict(
  1926. patch_size=16,
  1927. embed_dim=768,
  1928. depth=12,
  1929. num_heads=12,
  1930. mlp_ratio=4.0,
  1931. global_pool='map',
  1932. attn_type='rope',
  1933. use_pre_transformer_norm=True,
  1934. use_rot_pos_emb=True,
  1935. ref_feat_shape=(14, 14),
  1936. rope_grid_offset=1.,
  1937. rope_grid_indexing='xy',
  1938. attn_pool_num_heads=8,
  1939. attn_pool_mlp_ratio=4.,
  1940. norm_layer=partial(LayerNorm, eps=1e-5),
  1941. #dynamic_img_size=True
  1942. )
  1943. return _create_eva('vit_pe_core_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  1944. @register_model
  1945. def vit_pe_core_large_patch14_336(pretrained: bool = False, **kwargs) -> Eva:
  1946. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  1947. model_args = dict(
  1948. patch_size=14,
  1949. embed_dim=1024,
  1950. depth=24,
  1951. num_heads=16,
  1952. mlp_ratio=4.0,
  1953. global_pool='map',
  1954. attn_type='rope',
  1955. use_pre_transformer_norm=True,
  1956. use_rot_pos_emb=True,
  1957. ref_feat_shape=(24, 24),
  1958. rope_grid_offset=1.,
  1959. rope_grid_indexing='xy',
  1960. attn_pool_num_heads=8,
  1961. attn_pool_mlp_ratio=4.,
  1962. norm_layer=partial(LayerNorm, eps=1e-5),
  1963. #dynamic_img_size=True,
  1964. )
  1965. return _create_eva('vit_pe_core_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
  1966. @register_model
  1967. def vit_pe_core_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  1968. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  1969. model_args = dict(
  1970. patch_size=14,
  1971. embed_dim=1536,
  1972. depth=50,
  1973. num_heads=16,
  1974. mlp_ratio=8960 / 1536,
  1975. global_pool='map',
  1976. attn_type='rope',
  1977. class_token=False,
  1978. use_pre_transformer_norm=True,
  1979. use_rot_pos_emb=True,
  1980. ref_feat_shape=(32, 32),
  1981. rope_grid_indexing='xy',
  1982. attn_pool_num_heads=8,
  1983. attn_pool_mlp_ratio=4.,
  1984. norm_layer=partial(LayerNorm, eps=1e-5),
  1985. #dynamic_img_size=True,
  1986. )
  1987. return _create_eva('vit_pe_core_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  1988. @register_model
  1989. def vit_pe_lang_large_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  1990. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  1991. model_args = dict(
  1992. patch_size=14,
  1993. embed_dim=1024,
  1994. depth=23,
  1995. num_heads=16,
  1996. mlp_ratio=4.0,
  1997. attn_type='rope',
  1998. class_token=True,
  1999. use_rot_pos_emb=True,
  2000. ref_feat_shape=(32, 32),
  2001. rope_grid_offset=1.,
  2002. rope_grid_indexing='xy',
  2003. use_pre_transformer_norm=True,
  2004. use_post_transformer_norm=False,
  2005. use_fc_norm=False, # explicitly disable
  2006. init_values=0.1,
  2007. norm_layer=partial(LayerNorm, eps=1e-5),
  2008. #dynamic_img_size=True,
  2009. )
  2010. return _create_eva('vit_pe_lang_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2011. @register_model
  2012. def vit_pe_lang_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  2013. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2014. model_args = dict(
  2015. patch_size=14,
  2016. embed_dim=1536,
  2017. depth=47,
  2018. num_heads=16,
  2019. mlp_ratio=8960 / 1536,
  2020. attn_type='rope',
  2021. class_token=False,
  2022. use_rot_pos_emb=True,
  2023. ref_feat_shape=(32, 32),
  2024. rope_grid_indexing='xy',
  2025. use_pre_transformer_norm=True,
  2026. use_post_transformer_norm=False,
  2027. use_fc_norm=False, # explicitly disable
  2028. init_values=0.1,
  2029. norm_layer=partial(LayerNorm, eps=1e-5),
  2030. #dynamic_img_size=True,
  2031. )
  2032. return _create_eva('vit_pe_lang_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2033. @register_model
  2034. def vit_pe_spatial_tiny_patch16_512(pretrained: bool = False, **kwargs) -> Eva:
  2035. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2036. model_args = dict(
  2037. patch_size=16,
  2038. embed_dim=192,
  2039. depth=12,
  2040. num_heads=3,
  2041. mlp_ratio=4.0,
  2042. attn_type='rope',
  2043. use_pre_transformer_norm=True,
  2044. use_post_transformer_norm=False,
  2045. use_fc_norm=False, # explicitly disable
  2046. use_rot_pos_emb=True,
  2047. ref_feat_shape=(32, 32),
  2048. rope_grid_offset=1.,
  2049. rope_grid_indexing='xy',
  2050. norm_layer=partial(LayerNorm, eps=1e-5),
  2051. #dynamic_img_size=True
  2052. )
  2053. return _create_eva('vit_pe_spatial_tiny_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
  2054. @register_model
  2055. def vit_pe_spatial_small_patch16_512(pretrained: bool = False, **kwargs) -> Eva:
  2056. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2057. model_args = dict(
  2058. patch_size=16,
  2059. embed_dim=384,
  2060. depth=12,
  2061. num_heads=6,
  2062. mlp_ratio=4.0,
  2063. attn_type='rope',
  2064. use_pre_transformer_norm=True,
  2065. use_post_transformer_norm=False,
  2066. use_fc_norm=False, # explicitly disable
  2067. use_rot_pos_emb=True,
  2068. ref_feat_shape=(32, 32),
  2069. rope_grid_offset=1.,
  2070. rope_grid_indexing='xy',
  2071. norm_layer=partial(LayerNorm, eps=1e-5),
  2072. #dynamic_img_size=True
  2073. )
  2074. return _create_eva('vit_pe_spatial_small_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
  2075. @register_model
  2076. def vit_pe_spatial_base_patch16_512(pretrained: bool = False, **kwargs) -> Eva:
  2077. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2078. model_args = dict(
  2079. patch_size=16,
  2080. embed_dim=768,
  2081. depth=12,
  2082. num_heads=12,
  2083. mlp_ratio=4.0,
  2084. attn_type='rope',
  2085. use_pre_transformer_norm=True,
  2086. use_post_transformer_norm=False,
  2087. use_fc_norm=False, # explicitly disable
  2088. use_rot_pos_emb=True,
  2089. ref_feat_shape=(32, 32),
  2090. rope_grid_offset=1.,
  2091. rope_grid_indexing='xy',
  2092. norm_layer=partial(LayerNorm, eps=1e-5),
  2093. #dynamic_img_size=True
  2094. )
  2095. return _create_eva('vit_pe_spatial_base_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
  2096. @register_model
  2097. def vit_pe_spatial_large_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  2098. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2099. model_args = dict(
  2100. patch_size=14,
  2101. embed_dim=1024,
  2102. depth=24,
  2103. num_heads=16,
  2104. mlp_ratio=4.0,
  2105. attn_type='rope',
  2106. use_pre_transformer_norm=True,
  2107. use_post_transformer_norm=False,
  2108. use_fc_norm=False, # explicitly disable
  2109. use_rot_pos_emb=True,
  2110. ref_feat_shape=(32, 32),
  2111. rope_grid_offset=1.,
  2112. rope_grid_indexing='xy',
  2113. norm_layer=partial(LayerNorm, eps=1e-5),
  2114. #dynamic_img_size=True,
  2115. )
  2116. return _create_eva('vit_pe_spatial_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2117. @register_model
  2118. def vit_pe_spatial_gigantic_patch14_448(pretrained: bool = False, **kwargs) -> Eva:
  2119. """Perception Encoder (PE) ViT from Meta (https://arxiv.org/abs/2504.13181)"""
  2120. model_args = dict(
  2121. patch_size=14,
  2122. embed_dim=1536,
  2123. depth=50,
  2124. num_heads=16,
  2125. mlp_ratio=8960 / 1536,
  2126. attn_type='rope',
  2127. class_token=False,
  2128. use_rot_pos_emb=True,
  2129. ref_feat_shape=(32, 32),
  2130. rope_grid_indexing='xy',
  2131. use_pre_transformer_norm=True,
  2132. use_post_transformer_norm=False,
  2133. use_fc_norm=False, # explicitly disable
  2134. init_values=0.1,
  2135. norm_layer=partial(LayerNorm, eps=1e-5),
  2136. #dynamic_img_size=True,
  2137. )
  2138. return _create_eva('vit_pe_spatial_gigantic_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
  2139. # RoPE-ViT models from https://github.com/naver-ai/rope-vit
  2140. @register_model
  2141. def vit_small_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:
  2142. """RoPE-Axial ViT-S/16 from https://github.com/naver-ai/rope-vit"""
  2143. model_args = dict(
  2144. patch_size=16,
  2145. embed_dim=384,
  2146. depth=12,
  2147. num_heads=6,
  2148. mlp_ratio=4,
  2149. attn_type='rope',
  2150. qkv_bias=True,
  2151. init_values=1e-5,
  2152. class_token=True,
  2153. global_pool='token',
  2154. use_abs_pos_emb=False,
  2155. use_rot_pos_emb=True,
  2156. rope_grid_indexing='xy',
  2157. rope_temperature=100.0,
  2158. )
  2159. model = _create_eva('vit_small_patch16_rope_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2160. return model
  2161. @register_model
  2162. def vit_base_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:
  2163. """RoPE-Axial ViT-B/16 from https://github.com/naver-ai/rope-vit"""
  2164. model_args = dict(
  2165. patch_size=16,
  2166. embed_dim=768,
  2167. depth=12,
  2168. num_heads=12,
  2169. mlp_ratio=4,
  2170. attn_type='rope',
  2171. use_fc_norm=False,
  2172. qkv_bias=True,
  2173. init_values=1e-5,
  2174. class_token=True,
  2175. global_pool='token',
  2176. use_abs_pos_emb=False,
  2177. use_rot_pos_emb=True,
  2178. rope_grid_indexing='xy',
  2179. rope_temperature=100.0,
  2180. )
  2181. model = _create_eva('vit_base_patch16_rope_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2182. return model
  2183. @register_model
  2184. def vit_large_patch16_rope_224(pretrained: bool = False, **kwargs) -> Eva:
  2185. """RoPE-Axial ViT-L/16 from https://github.com/naver-ai/rope-vit"""
  2186. model_args = dict(
  2187. patch_size=16,
  2188. embed_dim=1024,
  2189. depth=24,
  2190. num_heads=16,
  2191. mlp_ratio=4,
  2192. attn_type='rope',
  2193. qkv_bias=True,
  2194. init_values=1e-5,
  2195. class_token=True,
  2196. global_pool='token',
  2197. use_abs_pos_emb=False,
  2198. use_rot_pos_emb=True,
  2199. rope_grid_indexing='xy',
  2200. rope_temperature=100.0,
  2201. )
  2202. model = _create_eva('vit_large_patch16_rope_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2203. return model
  2204. @register_model
  2205. def vit_small_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
  2206. """RoPE-Mixed ViT-S/16 from https://github.com/naver-ai/rope-vit"""
  2207. model_args = dict(
  2208. patch_size=16,
  2209. embed_dim=384,
  2210. depth=12,
  2211. num_heads=6,
  2212. mlp_ratio=4,
  2213. attn_type='rope',
  2214. qkv_bias=True,
  2215. init_values=1e-5,
  2216. class_token=True,
  2217. global_pool='token',
  2218. use_abs_pos_emb=False,
  2219. use_rot_pos_emb=True,
  2220. rope_grid_indexing='xy',
  2221. rope_temperature=10.0,
  2222. rope_type='mixed'
  2223. )
  2224. model = _create_eva('vit_small_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2225. return model
  2226. @register_model
  2227. def vit_base_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
  2228. """RoPE-Mixed ViT-B/16 from https://github.com/naver-ai/rope-vit"""
  2229. model_args = dict(
  2230. patch_size=16,
  2231. embed_dim=768,
  2232. depth=12,
  2233. num_heads=12,
  2234. mlp_ratio=4,
  2235. qkv_bias=True,
  2236. attn_type='rope',
  2237. init_values=1e-5,
  2238. class_token=True,
  2239. global_pool='token',
  2240. use_abs_pos_emb=False,
  2241. use_rot_pos_emb=True,
  2242. rope_grid_indexing='xy',
  2243. rope_temperature=10.0,
  2244. rope_type='mixed'
  2245. )
  2246. model = _create_eva('vit_base_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2247. return model
  2248. @register_model
  2249. def vit_large_patch16_rope_mixed_224(pretrained: bool = False, **kwargs) -> Eva:
  2250. """RoPE-Mixed ViT-L/16 from https://github.com/naver-ai/rope-vit"""
  2251. model_args = dict(
  2252. patch_size=16,
  2253. embed_dim=1024,
  2254. depth=24,
  2255. num_heads=16,
  2256. mlp_ratio=4,
  2257. attn_type='rope',
  2258. qkv_bias=True,
  2259. init_values=1e-5,
  2260. class_token=True,
  2261. global_pool='token',
  2262. use_abs_pos_emb=False,
  2263. use_rot_pos_emb=True,
  2264. rope_grid_indexing='xy',
  2265. rope_temperature=10.0,
  2266. rope_type='mixed'
  2267. )
  2268. model = _create_eva('vit_large_patch16_rope_mixed_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2269. return model
  2270. # APE variants (with absolute position embeddings)
  2271. @register_model
  2272. def vit_small_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2273. """RoPE-Axial + APE ViT-S/16 from https://github.com/naver-ai/rope-vit"""
  2274. model_args = dict(
  2275. patch_size=16,
  2276. embed_dim=384,
  2277. depth=12,
  2278. num_heads=6,
  2279. mlp_ratio=4,
  2280. attn_type='rope',
  2281. qkv_bias=True,
  2282. init_values=1e-5,
  2283. class_token=True,
  2284. global_pool='token',
  2285. no_embed_class=True,
  2286. use_abs_pos_emb=True,
  2287. use_rot_pos_emb=True,
  2288. rope_grid_indexing='xy',
  2289. rope_temperature=100.0,
  2290. )
  2291. model = _create_eva('vit_small_patch16_rope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2292. return model
  2293. @register_model
  2294. def vit_base_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2295. """RoPE-Axial + APE ViT-B/16 from https://github.com/naver-ai/rope-vit"""
  2296. model_args = dict(
  2297. patch_size=16,
  2298. embed_dim=768,
  2299. depth=12,
  2300. num_heads=12,
  2301. mlp_ratio=4,
  2302. attn_type='rope',
  2303. qkv_bias=True,
  2304. init_values=1e-5,
  2305. class_token=True,
  2306. global_pool='token',
  2307. no_embed_class=True,
  2308. use_abs_pos_emb=True,
  2309. use_rot_pos_emb=True,
  2310. rope_grid_indexing='xy',
  2311. rope_temperature=100.0,
  2312. )
  2313. model = _create_eva('vit_base_patch16_rope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2314. return model
  2315. @register_model
  2316. def vit_large_patch16_rope_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2317. """RoPE-Axial + APE ViT-L/16 from https://github.com/naver-ai/rope-vit"""
  2318. model_args = dict(
  2319. patch_size=16,
  2320. embed_dim=1024,
  2321. depth=24,
  2322. num_heads=16,
  2323. mlp_ratio=4,
  2324. attn_type='rope',
  2325. qkv_bias=True,
  2326. init_values=1e-5,
  2327. class_token=True,
  2328. global_pool='token',
  2329. no_embed_class=True,
  2330. use_abs_pos_emb=True,
  2331. use_rot_pos_emb=True,
  2332. rope_grid_indexing='xy',
  2333. rope_temperature=100.0,
  2334. )
  2335. model = _create_eva('vit_large_patch16_rope_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2336. return model
  2337. @register_model
  2338. def vit_small_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2339. """RoPE-Mixed + APE ViT-S/16 from https://github.com/naver-ai/rope-vit"""
  2340. model_args = dict(
  2341. patch_size=16,
  2342. embed_dim=384,
  2343. depth=12,
  2344. num_heads=6,
  2345. mlp_ratio=4,
  2346. attn_type='rope',
  2347. qkv_bias=True,
  2348. init_values=1e-5,
  2349. class_token=True,
  2350. global_pool='token',
  2351. no_embed_class=True,
  2352. use_abs_pos_emb=True,
  2353. use_rot_pos_emb=True,
  2354. rope_grid_indexing='xy',
  2355. rope_temperature=10.0,
  2356. rope_type='mixed'
  2357. )
  2358. model = _create_eva('vit_small_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2359. return model
  2360. @register_model
  2361. def vit_base_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2362. """RoPE-Mixed + APE ViT-B/16 from https://github.com/naver-ai/rope-vit"""
  2363. model_args = dict(
  2364. patch_size=16,
  2365. embed_dim=768,
  2366. depth=12,
  2367. num_heads=12,
  2368. mlp_ratio=4,
  2369. attn_type='rope',
  2370. qkv_bias=True,
  2371. init_values=1e-5,
  2372. class_token=True,
  2373. global_pool='token',
  2374. no_embed_class=True,
  2375. use_abs_pos_emb=True,
  2376. use_rot_pos_emb=True,
  2377. rope_grid_indexing='xy',
  2378. rope_temperature=10.0,
  2379. rope_type='mixed'
  2380. )
  2381. model = _create_eva('vit_base_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2382. return model
  2383. @register_model
  2384. def vit_large_patch16_rope_mixed_ape_224(pretrained: bool = False, **kwargs) -> Eva:
  2385. """RoPE-Mixed + APE ViT-L/16 from https://github.com/naver-ai/rope-vit"""
  2386. model_args = dict(
  2387. patch_size=16,
  2388. embed_dim=1024,
  2389. depth=24,
  2390. num_heads=16,
  2391. mlp_ratio=4,
  2392. attn_type='rope',
  2393. qkv_bias=True,
  2394. init_values=1e-5,
  2395. class_token=True,
  2396. global_pool='token',
  2397. no_embed_class=True,
  2398. use_abs_pos_emb=True,
  2399. use_rot_pos_emb=True,
  2400. rope_grid_indexing='xy',
  2401. rope_temperature=10.0,
  2402. rope_type='mixed'
  2403. )
  2404. model = _create_eva('vit_large_patch16_rope_mixed_ape_224', pretrained=pretrained, **dict(model_args, **kwargs))
  2405. return model
  2406. @register_model
  2407. def vit_small_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2408. """DINOv3 S/16 https://arxiv.org/abs/2508.10104"""
  2409. model_args = dict(
  2410. patch_size=16,
  2411. dynamic_img_size=True,
  2412. embed_dim=384,
  2413. depth=12,
  2414. num_heads=6,
  2415. qkv_bias=False,
  2416. init_values=1.0e-05, # layer-scale
  2417. rope_type='dinov3',
  2418. rope_temperature=100,
  2419. #rope_rescale_coords=2, # haven't added to interface
  2420. rope_rotate_half=True,
  2421. use_rot_pos_emb=True,
  2422. use_abs_pos_emb=False,
  2423. num_reg_tokens=4,
  2424. use_fc_norm=False,
  2425. norm_layer=partial(LayerNorm, eps=1e-5),
  2426. )
  2427. model = _create_eva('vit_small_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2428. return model
  2429. @register_model
  2430. def vit_small_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2431. """DINOv3 S/16 w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104"""
  2432. model_args = dict(
  2433. patch_size=16,
  2434. dynamic_img_size=True,
  2435. embed_dim=384,
  2436. depth=12,
  2437. num_heads=6,
  2438. qkv_bias=True,
  2439. init_values=1.0e-05, # layer-scale
  2440. rope_type='dinov3',
  2441. rope_temperature=100,
  2442. #rope_rescale_coords=2, # haven't added to interface
  2443. rope_rotate_half=True,
  2444. use_rot_pos_emb=True,
  2445. use_abs_pos_emb=False,
  2446. num_reg_tokens=4,
  2447. use_fc_norm=False,
  2448. norm_layer=partial(LayerNorm, eps=1e-5),
  2449. )
  2450. model = _create_eva('vit_small_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2451. return model
  2452. @register_model
  2453. def vit_small_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2454. """DINOv3 S/16 Plus https://arxiv.org/abs/2508.10104"""
  2455. model_args = dict(
  2456. patch_size=16,
  2457. dynamic_img_size=True,
  2458. embed_dim=384,
  2459. depth=12,
  2460. num_heads=6,
  2461. qkv_bias=False,
  2462. init_values=1.0e-05, # layer-scale
  2463. rope_type='dinov3',
  2464. rope_temperature=100,
  2465. #rope_rescale_coords=2, # haven't added to interface
  2466. rope_rotate_half=True,
  2467. use_rot_pos_emb=True,
  2468. use_abs_pos_emb=False,
  2469. swiglu_mlp=True,
  2470. swiglu_align_to=8,
  2471. num_reg_tokens=4,
  2472. use_fc_norm=False,
  2473. norm_layer=partial(LayerNorm, eps=1e-5),
  2474. )
  2475. model = _create_eva('vit_small_plus_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2476. return model
  2477. @register_model
  2478. def vit_small_plus_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2479. """DINOv3 S/16 Plus w/ QKV bias enabled (but 0) https://arxiv.org/abs/2508.10104"""
  2480. model_args = dict(
  2481. patch_size=16,
  2482. dynamic_img_size=True,
  2483. embed_dim=384,
  2484. depth=12,
  2485. num_heads=6,
  2486. qkv_bias=True,
  2487. init_values=1.0e-05, # layer-scale
  2488. rope_type='dinov3',
  2489. rope_temperature=100,
  2490. #rope_rescale_coords=2, # haven't added to interface
  2491. rope_rotate_half=True,
  2492. use_rot_pos_emb=True,
  2493. use_abs_pos_emb=False,
  2494. swiglu_mlp=True,
  2495. swiglu_align_to=8,
  2496. num_reg_tokens=4,
  2497. use_fc_norm=False,
  2498. norm_layer=partial(LayerNorm, eps=1e-5),
  2499. )
  2500. model = _create_eva('vit_small_plus_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2501. return model
  2502. @register_model
  2503. def vit_base_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2504. """DINOv3 B/16 https://arxiv.org/abs/2508.10104"""
  2505. model_args = dict(
  2506. patch_size=16,
  2507. dynamic_img_size=True,
  2508. embed_dim=768,
  2509. depth=12,
  2510. num_heads=12,
  2511. qkv_bias=False,
  2512. init_values=1.0e-05, # layer-scale
  2513. rope_type='dinov3',
  2514. rope_temperature=100,
  2515. #rope_rescale_coords=2, # haven't added to interface
  2516. rope_rotate_half=True,
  2517. use_rot_pos_emb=True,
  2518. use_abs_pos_emb=False,
  2519. num_reg_tokens=4,
  2520. use_fc_norm=False,
  2521. norm_layer=partial(LayerNorm, eps=1e-5),
  2522. )
  2523. model = _create_eva('vit_base_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2524. return model
  2525. @register_model
  2526. def vit_base_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2527. """DINOv3 B/16 w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104"""
  2528. model_args = dict(
  2529. patch_size=16,
  2530. dynamic_img_size=True,
  2531. embed_dim=768,
  2532. depth=12,
  2533. num_heads=12,
  2534. qkv_bias=True,
  2535. init_values=1.0e-05, # layer-scale
  2536. rope_type='dinov3',
  2537. rope_temperature=100,
  2538. #rope_rescale_coords=2, # haven't added to interface
  2539. rope_rotate_half=True,
  2540. use_rot_pos_emb=True,
  2541. use_abs_pos_emb=False,
  2542. num_reg_tokens=4,
  2543. use_fc_norm=False,
  2544. norm_layer=partial(LayerNorm, eps=1e-5),
  2545. )
  2546. model = _create_eva('vit_base_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2547. return model
  2548. @register_model
  2549. def vit_large_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2550. """DINOv3 L/16 https://arxiv.org/abs/2508.10104"""
  2551. model_args = dict(
  2552. patch_size=16,
  2553. dynamic_img_size=True,
  2554. embed_dim=1024,
  2555. depth=24,
  2556. num_heads=16,
  2557. qkv_bias=False,
  2558. init_values=1.0e-5, # layer-scale
  2559. rope_type='dinov3',
  2560. rope_temperature=100,
  2561. use_rot_pos_emb=True,
  2562. use_abs_pos_emb=False,
  2563. rope_rotate_half=True,
  2564. #rope_rescale_coords=2, # haven't added to interface
  2565. num_reg_tokens=4,
  2566. use_fc_norm=False,
  2567. norm_layer=partial(LayerNorm, eps=1e-5),
  2568. )
  2569. model = _create_eva('vit_large_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2570. return model
  2571. @register_model
  2572. def vit_large_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2573. """DINOv3 w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104"""
  2574. model_args = dict(
  2575. patch_size=16,
  2576. dynamic_img_size=True,
  2577. embed_dim=1024,
  2578. depth=24,
  2579. num_heads=16,
  2580. qkv_bias=True,
  2581. init_values=1.0e-5, # layer-scale
  2582. rope_type='dinov3',
  2583. rope_temperature=100,
  2584. use_rot_pos_emb=True,
  2585. use_abs_pos_emb=False,
  2586. rope_rotate_half=True,
  2587. #rope_rescale_coords=2, # haven't added to interface
  2588. num_reg_tokens=4,
  2589. use_fc_norm=False,
  2590. norm_layer=partial(LayerNorm, eps=1e-5),
  2591. )
  2592. model = _create_eva('vit_large_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2593. return model
  2594. @register_model
  2595. def vit_huge_plus_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2596. """DINOv3 H/16 Plus https://arxiv.org/abs/2508.10104"""
  2597. model_args = dict(
  2598. patch_size=16,
  2599. dynamic_img_size=True,
  2600. embed_dim=1280,
  2601. depth=32,
  2602. num_heads=20,
  2603. qkv_bias=False,
  2604. init_values=1.0e-5, # layer-scale
  2605. rope_type='dinov3',
  2606. rope_temperature=100,
  2607. use_rot_pos_emb=True,
  2608. use_abs_pos_emb=False,
  2609. rope_rotate_half=True,
  2610. swiglu_mlp=True,
  2611. swiglu_align_to=8,
  2612. #rope_rescale_coords=2, # haven't added to interface
  2613. num_reg_tokens=4,
  2614. use_fc_norm=False,
  2615. norm_layer=partial(LayerNorm, eps=1e-5),
  2616. )
  2617. model = _create_eva('vit_huge_plus_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2618. return model
  2619. @register_model
  2620. def vit_huge_plus_patch16_dinov3_qkvb(pretrained: bool = False, **kwargs) -> Eva:
  2621. """DINOv3 H/16 Plus w/ QKV bias enabled (but zero) https://arxiv.org/abs/2508.10104"""
  2622. model_args = dict(
  2623. patch_size=16,
  2624. dynamic_img_size=True,
  2625. embed_dim=1280,
  2626. depth=32,
  2627. num_heads=20,
  2628. qkv_bias=True,
  2629. init_values=1.0e-5, # layer-scale
  2630. rope_type='dinov3',
  2631. rope_temperature=100,
  2632. use_rot_pos_emb=True,
  2633. use_abs_pos_emb=False,
  2634. rope_rotate_half=True,
  2635. swiglu_mlp=True,
  2636. swiglu_align_to=8,
  2637. #rope_rescale_coords=2, # haven't added to interface
  2638. num_reg_tokens=4,
  2639. use_fc_norm=False,
  2640. norm_layer=partial(LayerNorm, eps=1e-5),
  2641. )
  2642. model = _create_eva('vit_huge_plus_patch16_dinov3_qkvb', pretrained=pretrained, **dict(model_args, **kwargs))
  2643. return model
  2644. @register_model
  2645. def vit_7b_patch16_dinov3(pretrained: bool = False, **kwargs) -> Eva:
  2646. """DINOv3 7B/16 https://arxiv.org/abs/2508.10104"""
  2647. model_args = dict(
  2648. patch_size=16,
  2649. dynamic_img_size=True,
  2650. embed_dim=4096,
  2651. depth=40,
  2652. num_heads=32,
  2653. qkv_bias=False,
  2654. mlp_ratio=2,
  2655. init_values=1.0e-5, # layer-scale
  2656. rope_type='dinov3',
  2657. rope_temperature=100,
  2658. use_rot_pos_emb=True,
  2659. use_abs_pos_emb=False,
  2660. rope_rotate_half=True,
  2661. swiglu_mlp=True,
  2662. swiglu_align_to=64,
  2663. #rope_rescale_coords=2, # haven't added to interface
  2664. num_reg_tokens=4,
  2665. use_fc_norm=False,
  2666. norm_layer=partial(LayerNorm, eps=1e-5),
  2667. )
  2668. model = _create_eva('vit_7b_patch16_dinov3', pretrained=pretrained, **dict(model_args, **kwargs))
  2669. return model