naflexvit.py 91 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235
  1. """ NaFlex Vision Transformer
  2. An improved version of the Vision Transformer with:
  3. 1. Encapsulated embedding and position encoding in a single module
  4. 2. Support for linear patch embedding on pre-patchified inputs
  5. 3. Support for NaFlex variable aspect, variable resolution
  6. 4. Support for FlexiViT variable patch size
  7. 5. Support for NaViT fractional/factorized position embedding
  8. Based on ideas from:
  9. - Original Vision Transformer: https://arxiv.org/abs/2010.11929
  10. - FlexiViT: https://arxiv.org/abs/2212.08013
  11. - NaViT: https://arxiv.org/abs/2307.06304
  12. - NaFlex (SigLip-2): https://arxiv.org/abs/2502.14786
  13. Hacked together by / Copyright 2025, Ross Wightman, Hugging Face
  14. """
  15. import logging
  16. import math
  17. from dataclasses import dataclass, fields, replace
  18. from functools import partial
  19. from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union, Any
  20. import torch
  21. import torch.nn as nn
  22. import torch.nn.functional as F
  23. from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
  24. from timm.layers import (
  25. AttentionPoolLatent,
  26. Mlp,
  27. LayerNorm,
  28. PatchDropoutWithIndices,
  29. PatchEmbedInterpolator,
  30. _assert,
  31. to_2tuple,
  32. get_act_layer,
  33. get_norm_layer,
  34. apply_keep_indices_nlc,
  35. disable_compiler,
  36. calculate_drop_path_rates,
  37. )
  38. from ._builder import build_model_with_cfg
  39. from ._features import feature_take_indices
  40. from ._features_fx import register_notrace_function, register_notrace_module
  41. from ._manipulate import checkpoint, named_apply
  42. from ._registry import register_model, generate_default_cfgs
  43. from .eva import EvaBlock
  44. from .vision_transformer import Block, global_pool_nlc
  45. __all__ = ['NaFlexVitCfg', 'NaFlexVit']
  46. _logger = logging.getLogger(__name__)
  47. @dataclass
  48. class NaFlexVitCfg:
  49. """Configuration for FlexVit model.
  50. This dataclass contains the bulk of model configuration parameters,
  51. with core parameters (img_size, in_chans, num_classes, etc.) remaining
  52. as direct constructor arguments for API compatibility.
  53. """
  54. # Architecture parameters
  55. patch_size: Union[int, Tuple[int, int]] = 16
  56. embed_dim: int = 768
  57. depth: int = 12
  58. num_heads: int = 12
  59. mlp_ratio: float = 4.0
  60. scale_mlp_norm: bool = False # Apply scaling norm to MLP
  61. # Attention parameters
  62. qkv_bias: bool = True
  63. qk_norm: bool = False
  64. proj_bias: bool = True
  65. attn_drop_rate: float = 0.0
  66. scale_attn_inner_norm: bool = False # Apply scaling norm to attn context
  67. # Regularization
  68. init_values: Optional[float] = None # Layer-scale init values (layer-scale enabled if not None)
  69. drop_rate: float = 0.0 # Dropout rate for classifier
  70. pos_drop_rate: float = 0.0 # Dropout rate for position embeddings
  71. patch_drop_rate: float = 0.0 # Dropout rate for patch tokens
  72. proj_drop_rate: float = 0.0 # Dropout rate for linear projections
  73. drop_path_rate: float = 0.0 # Stochastic depth drop rate
  74. # Prefix token configuration
  75. class_token: bool = False # Use class token
  76. reg_tokens: int = 0 # Number of register tokens
  77. # Position embedding configuration
  78. pos_embed: str = 'learned' # Type of position embedding ('learned', 'factorized', 'rope', 'none')
  79. pos_embed_grid_size: Optional[Tuple[int, int]] = (16, 16) # Grid size for position embedding initialization
  80. pos_embed_interp_mode: str = 'bicubic' # Interpolation mode for position embedding resizing
  81. pos_embed_ar_preserving: bool = False # Whether to preserve aspect ratio during position embedding interpolation
  82. pos_embed_use_grid_sample: bool = False # Whether to use grid_sample for naflex position embedding interpolation
  83. # ROPE specific configuration
  84. rope_type: str = '' # ROPE type: '' or 'none' for no ROPE, 'axial' for standard, 'mixed' for learnable frequencies
  85. rope_temperature: float = 10000.0 # Temperature for ROPE frequency computation
  86. rope_ref_feat_shape: Optional[Tuple[int, int]] = None
  87. rope_grid_offset: float = 0. # Grid offset for non-pixel ROPE mode
  88. rope_grid_indexing: str = 'ij' # Grid indexing mode for ROPE ('ij' or 'xy')
  89. # Image processing
  90. dynamic_img_pad: bool = False # Whether to enable dynamic padding for variable resolution
  91. # Other architecture choices
  92. pre_norm: bool = False # Whether to apply normalization before attention/MLP layers (start of blocks)
  93. final_norm: bool = True # Whether to apply final normalization before pooling and classifier (end of blocks)
  94. fc_norm: Optional[bool] = None # Whether to normalize features before final classifier (after pooling)
  95. # Global pooling setup
  96. global_pool: str = 'map' # Type of global pooling for final sequence
  97. pool_include_prefix: bool = False # Whether to include class/register prefix tokens in global pooling
  98. attn_pool_num_heads: Optional[int] = None # Override num_heads for attention pool
  99. attn_pool_mlp_ratio: Optional[float] = None # Override mlp_ratio for attention pool
  100. # Weight initialization
  101. weight_init: str = '' # Weight initialization scheme
  102. fix_init: bool = True # Apply weight initialization fix (scaling w/ layer index)
  103. # Embedding configuration
  104. embed_proj_type: str = 'linear' # Type of embedding layer ('conv' or 'linear')
  105. input_norm_layer: Optional[str] = None # Normalization layer for embeddings input (before input projection)
  106. embed_norm_layer: Optional[str] = None # Normalization layer for embeddings (after input projection)
  107. # Layer implementations
  108. norm_layer: Optional[str] = None # Normalization layer for transformer blocks
  109. act_layer: Optional[str] = None # Activation layer for MLP blocks
  110. block_fn: Optional[str] = None # Transformer block implementation class name
  111. mlp_layer: Optional[str] = None # MLP implementation class name
  112. # EVA-specific parameters
  113. attn_type: str = 'standard' # Attention type: 'standard', 'eva', 'rope'
  114. swiglu_mlp: bool = False # Use SwiGLU MLP variant
  115. qkv_fused: bool = True # Whether to use fused QKV projections
  116. # Variable patch size support
  117. enable_patch_interpolator: bool = False # Enable dynamic patch size support
  118. def _overlay_kwargs(cfg: NaFlexVitCfg, **kwargs) -> NaFlexVitCfg:
  119. """Overlay kwargs onto config, replacing config values with provided kwargs."""
  120. # Only update fields that exist in the config
  121. config_fields = set(cfg.__dataclass_fields__.keys())
  122. config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
  123. if config_kwargs:
  124. cfg = replace(cfg, **config_kwargs)
  125. return cfg
  126. def batch_patchify(
  127. x: torch.Tensor,
  128. patch_size: Tuple[int, int],
  129. pad: bool = True,
  130. ) -> Tuple[torch.Tensor, Tuple[int, int]]:
  131. """Patchify a batch of images.
  132. Args:
  133. x: Input tensor of shape [B, C, H, W].
  134. patch_size: Patch dimensions (patch_h, patch_w).
  135. pad: Whether to pad images to be divisible by patch size.
  136. Returns:
  137. Tuple of (patches, grid_size) where patches has shape [B, N, P*P*C]
  138. and grid_size is (num_patches_h, num_patches_w).
  139. """
  140. B, C, H, W = x.shape
  141. ph, pw = patch_size
  142. # Ensure the image is divisible by patch size
  143. if pad and (H % ph != 0 or W % pw != 0):
  144. pad_h = (ph - H % ph) % ph
  145. pad_w = (pw - W % pw) % pw
  146. x = F.pad(x, (0, pad_w, 0, pad_h))
  147. nh, nw = H // ph, W // pw
  148. patches = x.view(B, C, nh, ph, nw, pw).permute(0, 2, 4, 3, 5, 1).reshape(B, nh * nw, ph * pw * C)
  149. # FIXME confirm we want 'channels last' in the patch channel layout, egg ph, ph, C instead of C, ph, hw
  150. return patches, (nh, nw)
  151. def calculate_naflex_grid_sizes(_coord: torch.Tensor):
  152. # Calculate the appropriate grid size from coords
  153. max_y = _coord[:, :, 0].amax(dim=1) + 1
  154. max_x = _coord[:, :, 1].amax(dim=1) + 1
  155. return [(int(h.item()), int(w.item())) for h, w in zip(max_y, max_x)]
  156. class NaFlexRopeIterator:
  157. """Iterator for generating batched ROPE embeddings for mixed mode with multiple grid sizes."""
  158. def __init__(
  159. self,
  160. rope_module,
  161. size_to_indices: Dict[Tuple[int, int], List[int]],
  162. unique_sizes: List[Tuple[int, int]],
  163. batch_size: int,
  164. seq_len: int,
  165. device: torch.device,
  166. dtype: torch.dtype,
  167. ):
  168. self.rope = rope_module
  169. self.size_to_indices = size_to_indices
  170. self.unique_sizes = unique_sizes
  171. self.batch_size = batch_size
  172. self.seq_len = seq_len
  173. self.dtype = dtype
  174. self.device = device
  175. self.depth = rope_module.depth
  176. self.num_heads = rope_module.num_heads
  177. self.head_dim = 2 * rope_module.dim // rope_module.num_heads
  178. self._depth_idx = 0
  179. # Pre-compute embeddings for each unique size
  180. self._embeddings_per_size = {}
  181. for grid_size in unique_sizes:
  182. # get_embed returns all depths at once for mixed mode
  183. rope_embed = rope_module.get_embed(shape=grid_size)
  184. self._embeddings_per_size[grid_size] = rope_embed
  185. def __iter__(self):
  186. self._depth_idx = 0
  187. return self
  188. @disable_compiler
  189. def __next__(self):
  190. if self._depth_idx >= self.depth:
  191. raise StopIteration
  192. # Create batch tensor for current depth
  193. batch_embed = torch.zeros(
  194. self.batch_size, self.num_heads, self.seq_len, self.head_dim,
  195. dtype=self.dtype, device=self.device
  196. )
  197. # Fill in embeddings for each unique grid size
  198. for grid_size in self.unique_sizes:
  199. h, w = grid_size
  200. actual_len = h * w
  201. batch_indices = self.size_to_indices[grid_size]
  202. # Get pre-computed embeddings for this size at current depth
  203. embed = self._embeddings_per_size[grid_size][self._depth_idx] # [num_heads, H*W, dim]
  204. # Assign to batch indices
  205. for bi in batch_indices:
  206. batch_embed[bi, :, :actual_len, :] = embed[:, :actual_len, :]
  207. self._depth_idx += 1
  208. return batch_embed
  209. def get_block_fn(cfg: NaFlexVitCfg) -> Callable:
  210. """Get appropriate block function based on configuration.
  211. Returns a partially applied block constructor with EVA-specific
  212. or conflicting parameters pre-configured if needed.
  213. """
  214. # Check if we need EVA block features
  215. use_eva_features = (
  216. cfg.attn_type in ('eva', 'rope') or
  217. cfg.rope_type not in ('', 'none') or # Any ROPE type requires EVA blocks
  218. cfg.swiglu_mlp
  219. )
  220. if use_eva_features:
  221. # Determine attention type based on rope_type if not explicitly set
  222. attn_type = cfg.attn_type
  223. if attn_type == 'standard' and cfg.rope_type not in ('', 'none'):
  224. attn_type = 'rope'
  225. num_prefix_tokens = (1 if cfg.class_token else 0) + cfg.reg_tokens
  226. return partial(
  227. EvaBlock,
  228. attn_type=attn_type,
  229. swiglu_mlp=cfg.swiglu_mlp,
  230. scale_mlp=cfg.scale_mlp_norm,
  231. scale_attn_inner=cfg.scale_attn_inner_norm,
  232. qkv_fused=cfg.qkv_fused,
  233. num_prefix_tokens=num_prefix_tokens,
  234. )
  235. else:
  236. # Standard ViT block
  237. block_fn = cfg.block_fn or Block
  238. if cfg.scale_mlp_norm or cfg.scale_attn_inner_norm:
  239. # param names differ between EVA vs non-EVA block types
  240. block_fn = partial(
  241. block_fn,
  242. scale_mlp_norm=cfg.scale_mlp_norm,
  243. scale_attn_norm=cfg.scale_attn_inner_norm
  244. )
  245. return block_fn
  246. @register_notrace_module
  247. class NaFlexEmbeds(nn.Module):
  248. """NaFlex Embedding module for Vision Transformers.
  249. This module encapsulates the complete embedding process for Vision Transformers,
  250. supporting both standard and NaFlex (NaViT + FlexiViT) functionality:
  251. 1. Patch embedding (via Conv2d or Linear)
  252. 2. Class and register token preparation
  253. 3. Position embedding addition with interpolation support
  254. 4. Pre-normalization (if requested)
  255. 5. Dropout application
  256. NaFlex capabilities include:
  257. - Variable aspect ratio and resolution via patch coordinates
  258. - Patch type indicators for handling padding tokens in attention
  259. - Flexible position embedding interpolation for arbitrary grid sizes
  260. - Support for factorized position embeddings
  261. The patch embedding can be one of two types:
  262. - Conv2d-based (default): For standard image inputs [B, C, H, W]
  263. - Linear-based: For pre-patchified inputs [B, N, P*P*C]
  264. Args:
  265. patch_size: Size of patches for patch embedding
  266. in_chans: Number of input image channels
  267. embed_dim: Dimensionality of patch embedding
  268. proj_type: Type of embedding projection layer ('conv' or 'linear')
  269. input_norm_layer: Normalization layer applied to input (linear mode only)
  270. proj_norm_layer: Normalization layer applied after projection
  271. pos_embed: Type of position embedding ('learned', 'factorized', 'none')
  272. pos_drop_rate: Dropout rate for position embeddings
  273. class_token: Whether to include a class token
  274. reg_tokens: Number of register tokens to include
  275. proj_bias: Whether to use bias in projection layers
  276. dynamic_img_pad: Whether to enable dynamic padding for variable resolution
  277. pos_embed_grid_size: Grid size for position embedding initialization
  278. pos_embed_interp_mode: Interpolation mode for position embedding resizing
  279. pos_embed_ar_preserving: Whether to preserve aspect ratio during position embedding interpolation
  280. default_img_size: Default image size for position embedding grid calculation
  281. """
  282. def __init__(
  283. self,
  284. patch_size: Union[int, Tuple[int, int]] = 16,
  285. in_chans: int = 3,
  286. embed_dim: int = 768,
  287. proj_type: Optional[str] = None,
  288. proj_bias: bool = True,
  289. class_token: bool = True,
  290. reg_tokens: int = 0,
  291. dynamic_img_pad: bool = False,
  292. default_img_size: Optional[Union[int, Tuple[int, int]]] = None,
  293. pos_embed: str = 'learned',
  294. pos_embed_grid_size: Optional[Tuple[int, int]] = (14, 14),
  295. pos_embed_interp_mode: str = 'bicubic',
  296. pos_embed_ar_preserving: bool = False,
  297. pos_embed_use_grid_sample: bool = False,
  298. input_norm_layer: Optional[Type[nn.Module]] = None,
  299. proj_norm_layer: Union[bool, Optional[Type[nn.Module]]] = None,
  300. norm_layer: Optional[Type[nn.Module]] = None,
  301. pos_drop_rate: float = 0.,
  302. enable_patch_interpolator: bool = False,
  303. device=None,
  304. dtype=None,
  305. ) -> None:
  306. """Initialize NaFlexEmbeds module.
  307. Args:
  308. patch_size: Size of patches for patch embedding.
  309. in_chans: Number of input image channels.
  310. embed_dim: Dimensionality of patch embedding.
  311. proj_type: Type of embedding projection layer ('conv' or 'linear').
  312. proj_bias: Whether to use bias in projection layers.
  313. class_token: Whether to include a class token.
  314. reg_tokens: Number of register tokens to include.
  315. dynamic_img_pad: Whether to enable dynamic padding for variable resolution.
  316. default_img_size: Default image size for position embedding grid calculation.
  317. pos_embed: Type of position embedding ('learned', 'factorized', 'none').
  318. pos_embed_grid_size: Grid size for position embedding initialization.
  319. pos_embed_interp_mode: Interpolation mode for position embedding resizing.
  320. pos_embed_ar_preserving: Whether to preserve aspect ratio during interpolation.
  321. input_norm_layer: Normalization layer applied to input (linear mode only).
  322. proj_norm_layer: Normalization layer applied after projection.
  323. norm_layer: Default normalization layer.
  324. pos_drop_rate: Dropout rate for position embeddings.
  325. enable_patch_interpolator: Enable dynamic patch size support.
  326. """
  327. dd = {'device': device, 'dtype': dtype}
  328. super().__init__()
  329. self.has_class_token = class_token
  330. self.num_reg_tokens = reg_tokens
  331. self.pos_embed_interp_mode = pos_embed_interp_mode
  332. self.pos_embed_ar_preserving = pos_embed_ar_preserving
  333. self.pos_embed_use_grid_sample = pos_embed_use_grid_sample
  334. self.patch_size = to_2tuple(patch_size)
  335. self.in_chans = in_chans
  336. self.embed_dim = embed_dim
  337. self.dynamic_img_pad = dynamic_img_pad
  338. self.enable_patch_interpolator = enable_patch_interpolator
  339. # Calculate number of prefix tokens
  340. self.num_prefix_tokens = 1 if class_token else 0
  341. self.num_prefix_tokens += reg_tokens
  342. # Create class and register tokens
  343. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd)) if class_token else None
  344. self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim, **dd)) if reg_tokens else None
  345. # Calculate grid size and number of patches
  346. self.default_img_size: Optional[Tuple[int, int]] = None
  347. self.pos_embed_grid_size: Optional[Tuple[int, int]] = None # Grid size used for learned pos embed init
  348. if pos_embed_grid_size is not None:
  349. # Highest priority, use provided pos_embed_grid_size
  350. self.pos_embed_grid_size = pos_embed_grid_size
  351. elif default_img_size is not None:
  352. # Fallback to calculating grid size from img_size + patch_size if img size provided.
  353. self.default_img_size = to_2tuple(default_img_size)
  354. self.pos_embed_grid_size = tuple([s // p for s, p in zip(self.default_img_size, self.patch_size)])
  355. # Determine patch embedding type (linear or conv2d)
  356. if proj_type == 'linear':
  357. # Create linear projection for pre-patchified inputs
  358. # Input dimension is patch_size^2 * in_chans
  359. patch_dim = self.patch_size[0] * self.patch_size[1] * in_chans
  360. assert not (input_norm_layer is True and norm_layer is None), \
  361. "`norm_layer` must be given when input_norm_layer=True"
  362. input_norm_layer = norm_layer if input_norm_layer is True else (input_norm_layer or None)
  363. self.norm_input = input_norm_layer(patch_dim) if input_norm_layer else None
  364. self.proj = nn.Linear(patch_dim, embed_dim, bias=proj_bias, **dd)
  365. self.flatten = False
  366. self.is_linear = True
  367. else:
  368. # Default to convolutional patch embedding for image inputs
  369. assert not input_norm_layer
  370. self.norm_input = None
  371. self.proj = nn.Conv2d(
  372. in_chans,
  373. embed_dim,
  374. kernel_size=patch_size,
  375. stride=patch_size,
  376. bias=proj_bias,
  377. **dd,
  378. )
  379. self.flatten = True
  380. self.is_linear = False
  381. # Create patch embedding interpolator if enabled
  382. if self.enable_patch_interpolator:
  383. self.patch_interpolator = PatchEmbedInterpolator(
  384. base_patch_size=self.patch_size,
  385. in_chans=in_chans,
  386. embed_dim=embed_dim,
  387. interpolation=pos_embed_interp_mode,
  388. antialias=True,
  389. )
  390. else:
  391. self.patch_interpolator = None
  392. # Create normalization layer after the projection
  393. assert not (proj_norm_layer is True and norm_layer is None), \
  394. "`norm_layer` must be given when proj_norm_layer=True"
  395. proj_norm_layer = norm_layer if proj_norm_layer is True else (proj_norm_layer or None)
  396. self.norm = proj_norm_layer(embed_dim) if proj_norm_layer else nn.Identity()
  397. # Create position embedding if needed - only for patches, never for prefix tokens
  398. if pos_embed in ('factorized', 'learned') and self.pos_embed_grid_size is None:
  399. raise ValueError(
  400. "Cannot initialize position embeddings without grid_size."
  401. "Please provide img_size or pos_embed_grid_size.")
  402. self.pos_embed: Optional[torch.Tensor] = None
  403. self.pos_embed_y: Optional[torch.Tensor] = None
  404. self.pos_embed_x: Optional[torch.Tensor] = None
  405. if not pos_embed or pos_embed == 'none':
  406. self.pos_embed_type = 'none'
  407. elif pos_embed == 'factorized':
  408. assert self.pos_embed_grid_size is not None
  409. h, w = self.pos_embed_grid_size
  410. self.pos_embed_type = 'factorized'
  411. self.pos_embed_y = nn.Parameter(torch.randn(1, h, embed_dim, **dd) * .02)
  412. self.pos_embed_x = nn.Parameter(torch.randn(1, w, embed_dim, **dd) * .02)
  413. else:
  414. assert self.pos_embed_grid_size is not None
  415. h, w = self.pos_embed_grid_size
  416. self.pos_embed = nn.Parameter(torch.randn(1, h, w, embed_dim, **dd) * .02)
  417. self.pos_embed_type = 'learned'
  418. # Dropout layer
  419. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  420. def feature_info(self, location) -> Dict[str, Any]:
  421. """Get feature information for feature extraction.
  422. Args:
  423. location: Feature extraction location identifier
  424. Returns:
  425. Dictionary containing feature channel count and reduction factor
  426. """
  427. return dict(num_chs=self.embed_dim, reduction=self.patch_size)
  428. def feat_ratio(self, as_scalar: bool = True) -> Union[int, Tuple[int, int]]:
  429. """Get the feature reduction ratio (stride) of the patch embedding.
  430. Args:
  431. as_scalar: Whether to return the maximum dimension as a scalar
  432. Returns:
  433. Feature reduction ratio as scalar or tuple
  434. """
  435. if as_scalar:
  436. return max(self.patch_size)
  437. else:
  438. return self.patch_size
  439. def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
  440. """Calculate grid (feature) size for given image size.
  441. Takes into account dynamic padding when enabled.
  442. Args:
  443. img_size: Input image size as (height, width)
  444. Returns:
  445. Grid size as (grid_height, grid_width)
  446. """
  447. if self.dynamic_img_pad:
  448. return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
  449. else:
  450. return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
  451. @disable_compiler
  452. def _apply_learned_naflex_pos_embed(
  453. self,
  454. x: torch.Tensor,
  455. patch_coord: torch.Tensor,
  456. ) -> None:
  457. """Apply learned position embeddings to NaFlex batch in-place.
  458. Interpolates learned 2D position embeddings for each sample in the batch
  459. based on their individual grid sizes.
  460. Args:
  461. x: Input tensor to add position embeddings to [B, N, C]
  462. patch_coord: Patch coordinates [B, N, 2] with (y, x) values
  463. """
  464. # Calculate grid sizes from patch coordinates
  465. naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord)
  466. orig_h, orig_w = self.pos_embed.shape[1:3]
  467. pos_embed_nchw = self.pos_embed.permute(0, 3, 1, 2).float() # B,C,H,W
  468. def _interp2d(size):
  469. """
  470. Return a flattened positional-embedding grid at an arbitrary spatial resolution.
  471. Converts the learned 2-D table stored in NCHW format (pos_embed_nchw) into
  472. a (1, H*W, C) sequence that matches the requested size.
  473. """
  474. if (size[0] == orig_h) and (size[1] == orig_w):
  475. pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
  476. else:
  477. _interp_size = to_2tuple(max(size)) if self.pos_embed_ar_preserving else size
  478. pos_embed_flat = F.interpolate(
  479. pos_embed_nchw,
  480. size=_interp_size,
  481. mode=self.pos_embed_interp_mode,
  482. align_corners=False,
  483. antialias=True,
  484. )[:, :, :size[0], :size[1]].flatten(2).transpose(1, 2)
  485. return pos_embed_flat.to(dtype=x.dtype)
  486. # Determine unique grid sizes to avoid duplicate interpolation
  487. size_to_indices: Dict[Tuple[int, int], List[int]] = {}
  488. for bi, k in enumerate(naflex_grid_sizes):
  489. # k = h << 16 | w # FIXME can get jit compat with this
  490. size_to_indices.setdefault(k, []).append(bi)
  491. for k, batch_indices in size_to_indices.items():
  492. # h, w = k >> 16, k & 0xFFFF # FIXME can get jit compat with this
  493. # Interpolate only once for this (h, w)
  494. pos_embed_flat = _interp2d(k)
  495. seq_len = min(x.shape[1], pos_embed_flat.shape[1])
  496. x[:, :seq_len].index_add_(
  497. 0,
  498. torch.as_tensor(batch_indices, device=x.device),
  499. pos_embed_flat[:, :seq_len].expand(len(batch_indices), -1, -1)
  500. )
    @disable_compiler
    def _apply_learned_naflex_pos_embed_grid_sample(
            self,
            x: torch.Tensor,
            patch_coord: torch.Tensor,
    ) -> None:
        """Apply learned position embeddings to NaFlex batch using grid_sample.

        Uses F.grid_sample for efficient interpolation of learned 2D position embeddings
        based on patch coordinates. Based on proposal by https://github.com/stas-sl

        One output grid, sized by the largest per-sample extent in the batch, is
        sampled for every sample with a per-sample affine transform; embeddings
        are then gathered at each token's (y, x) coordinate and added to ``x``
        in-place.

        Args:
            x: Input tensor to add position embeddings to [B, N, C]
            patch_coord: Patch coordinates [B, N, 2] with (y, x) values
        """
        device = x.device
        B, N, C = x.shape
        # Per-sample grid extent: largest coordinate along each axis, plus one
        shapes = patch_coord.max(dim=1).values + 1  # (B, 2) containing [h_i, w_i]

        if self.pos_embed_ar_preserving:
            # Square output grid and one shared zoom factor for both axes
            L_i = shapes.amax(dim=1)  # (B,) max(h_i, w_i)
            L_global = L_i.amax()
            grid_size_y = grid_size_x = L_global
            scale_x = scale_y = L_global / L_i  # uniform zoom (B,)
        else:
            # Output grid is the per-axis maximum over the batch, zoom is per axis
            grid_size_y, grid_size_x = shapes.amax(dim=0)  # (2,)
            scale_y = grid_size_y / shapes[:, 0]  # vertical zoom (B,)
            scale_x = grid_size_x / shapes[:, 1]  # horizontal zoom (B,)

        # Per-sample affine map in normalized coordinates: in = scale * out + (scale - 1).
        # out = -1 maps to in = -1 (top-left stays anchored); with the scales above the
        # end of a sample's own extent maps to in = +1, so that extent covers the whole table.
        theta = torch.zeros(B, 2, 3, device=device, dtype=torch.float32)
        theta[:, 0, 0] = scale_x
        theta[:, 1, 1] = scale_y
        theta[:, 0, 2] = scale_x - 1  # translate x
        theta[:, 1, 2] = scale_y - 1  # translate y

        grid = F.affine_grid(theta, (B, C, grid_size_y, grid_size_x), align_corners=False)
        # Sample in float32, then cast back to the dtype of x
        pos_embed = F.grid_sample(
            self.pos_embed.permute(0, 3, 1, 2).expand(B, -1, -1, -1).float(),
            grid,
            mode=self.pos_embed_interp_mode,
            align_corners=False,
            padding_mode='border',
        ).to(dtype=x.dtype)  # (B, C, H_out, W_out)

        # Gather one embedding per token using its (y, x) coordinate; bi broadcasts over N
        bi = torch.arange(B, device=device, dtype=torch.long).unsqueeze(1)
        x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]]  # NOTE leave as '+='
  541. def _apply_learned_pos_embed(
  542. self,
  543. x: torch.Tensor,
  544. grid_size: List[int],
  545. ) -> None:
  546. """Apply learned position embeddings to standard 2D batch in-place.
  547. Interpolates learned 2D position embeddings to match the specified grid size.
  548. Args:
  549. x: Input tensor to add position embeddings to [B, H*W, C]
  550. grid_size: Target grid size as [height, width]
  551. """
  552. orig_h, orig_w = self.pos_embed.shape[1:3]
  553. if grid_size[0] == orig_h and grid_size[1] == orig_w:
  554. # No resize needed, just flatten
  555. pos_embed_flat = self.pos_embed.reshape(1, orig_h * orig_w, -1)
  556. else:
  557. # Resize if needed - directly using F.interpolate
  558. if self.pos_embed_ar_preserving:
  559. L = max(grid_size)
  560. _interp_size = L, L
  561. else:
  562. _interp_size = grid_size
  563. pos_embed_flat = F.interpolate(
  564. self.pos_embed.permute(0, 3, 1, 2).float(), # B,C,H,W
  565. size=_interp_size,
  566. mode=self.pos_embed_interp_mode,
  567. align_corners=False,
  568. antialias=True,
  569. )[:, :, :grid_size[0], :grid_size[1]].flatten(2).transpose(1, 2)
  570. pos_embed_flat = pos_embed_flat.to(dtype=x.dtype)
  571. x.add_(pos_embed_flat)
  572. @disable_compiler
  573. def _apply_factorized_naflex_pos_embed(
  574. self,
  575. x: torch.Tensor,
  576. patch_coord: torch.Tensor,
  577. ) -> None:
  578. """Apply factorized position embeddings to NaFlex batch in-place.
  579. Uses separate Y and X position embedding tables that are interpolated
  580. and combined for each sample's grid size.
  581. Args:
  582. x: Input tensor to add position embeddings to [B, N, C]
  583. patch_coord: Patch coordinates [B, N, 2] with (y, x) values
  584. """
  585. # Calculate grid sizes from patch coordinates
  586. naflex_grid_sizes = calculate_naflex_grid_sizes(patch_coord)
  587. assert len(naflex_grid_sizes) == x.size(0) # one (H,W) per sample
  588. # Handle each batch element separately with its own grid size
  589. orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]
  590. # bucket samples that share the same (H, W) so we build each grid once
  591. size_to_indices: Dict[Tuple[int, int], List[int]] = {}
  592. for bi, k in enumerate(naflex_grid_sizes):
  593. size_to_indices.setdefault(k, []).append(bi)
  594. def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor:
  595. """
  596. Resample a 1-D positional-embedding table to specified length
  597. and return it in (1, L, C) layout, dtype matching x.
  598. """
  599. if new_length == orig_length:
  600. return table.to(dtype=x.dtype)
  601. return F.interpolate(
  602. table.permute(0, 2, 1).float(), # (1,C,L) → (1,C,L_out)
  603. size=new_length,
  604. mode='linear',
  605. align_corners=False,
  606. ).permute(0, 2, 1).to(dtype=x.dtype) # → (1,L_out,C)
  607. for k, batch_indices in size_to_indices.items():
  608. target_h, target_w = k
  609. if self.pos_embed_ar_preserving:
  610. len_y = len_x = max(target_h, target_w)
  611. else:
  612. len_y, len_x = target_h, target_w
  613. pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h] # (1,H,C)
  614. pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w] # (1,W,C)
  615. # Broadcast, add and flatten to sequence layout (row major)
  616. pos = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1,H,W,C)
  617. pos = pos.flatten(1, 2)
  618. seq_len = min(x.shape[1], pos.shape[1])
  619. x[:, :seq_len].index_add_(
  620. 0,
  621. torch.as_tensor(batch_indices, device=x.device),
  622. pos[:, :seq_len].expand(len(batch_indices), -1, -1)
  623. )
  624. @disable_compiler
  625. def _apply_factorized_naflex_pos_embed_grid_sample(
  626. self,
  627. x: torch.Tensor,
  628. patch_coord: torch.Tensor,
  629. ) -> None:
  630. """Apply factorized position embeddings to NaFlex batch using grid_sample.
  631. Uses F.grid_sample for efficient interpolation of separate Y and X position
  632. embedding tables based on patch coordinates. Based on proposal by https://github.com/stas-sl
  633. Args:
  634. x: Input tensor to add position embeddings to [B, N, C]
  635. patch_coord: Patch coordinates [B, N, 2] with (y, x) values
  636. """
  637. device = x.device
  638. B, _, C = x.shape
  639. shapes = patch_coord.amax(dim=1) + 1
  640. if self.pos_embed_ar_preserving:
  641. # Aspect ratio preserving mode: use square grid with uniform scaling
  642. L_i = shapes.amax(dim=1) # (B,) max(h_i, w_i)
  643. L_global = L_i.amax()
  644. grid_size_y = grid_size_x = L_global
  645. scale_x = scale_y = L_global / L_i # uniform zoom (B,)
  646. else:
  647. # Standard mode: different scaling for x and y
  648. grid_size_y, grid_size_x = shapes.amax(0)
  649. scale_x = grid_size_x / shapes[:, 1] # horizontal zoom (B,)
  650. scale_y = grid_size_y / shapes[:, 0] # vertical zoom (B,)
  651. def _interp1d(table: torch.Tensor, scale: torch.Tensor, out_length: torch.Tensor) -> torch.Tensor:
  652. pe = table.permute(0, 2, 1).unsqueeze(2).expand(B, -1, -1, -1).float() # (1, L, C) -> (B, C, 1, L)
  653. theta = torch.zeros(B, 2, 3, device=x.device)
  654. theta[:, 0, 0] = scale
  655. theta[:, 0, 2] = scale - 1
  656. theta[:, 1, 1] = 1
  657. grid = F.affine_grid(theta, (B, C, 1, out_length), align_corners=False)
  658. pe = F.grid_sample(pe, grid, mode='bilinear', align_corners=False, padding_mode='border')
  659. return pe.to(x.dtype)
  660. # Interpolate along each axis
  661. pe_x = _interp1d(self.pos_embed_x, scale=scale_x, out_length=grid_size_x)
  662. pe_y = _interp1d(self.pos_embed_y, scale=scale_y, out_length=grid_size_y)
  663. bi = torch.arange(B, device=device, dtype=torch.long).unsqueeze(1)
  664. x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]]
  665. def _apply_factorized_pos_embed(
  666. self,
  667. x: torch.Tensor,
  668. grid_size: List[int],
  669. ) -> None:
  670. """Apply factorized position embeddings to standard 2D batch in-place.
  671. Uses separate Y and X position embedding tables that are interpolated
  672. and combined for the specified grid size.
  673. Args:
  674. x: Input tensor to add position embeddings to [B, H*W, C]
  675. grid_size: Target grid size as [height, width]
  676. """
  677. orig_h, orig_w = self.pos_embed_y.shape[1], self.pos_embed_x.shape[1]
  678. target_h, target_w = grid_size
  679. if self.pos_embed_ar_preserving:
  680. len_y = len_x = max(target_h, target_w)
  681. else:
  682. len_y, len_x = target_h, target_w
  683. def _interp1d(table: torch.Tensor, new_length: int, orig_length: int) -> torch.Tensor:
  684. if new_length == orig_length:
  685. return table.to(dtype=x.dtype)
  686. return F.interpolate(
  687. table.permute(0, 2, 1).float(), # (1,L,C) -> (1,C,L)
  688. size=new_length,
  689. mode='linear',
  690. align_corners=False,
  691. ).permute(0, 2, 1).to(dtype=x.dtype) # (1,L,C)
  692. # Interpolate embeddings
  693. pe_y = _interp1d(self.pos_embed_y, len_y, orig_h)[:, :target_h] # (1,H,C)
  694. pe_x = _interp1d(self.pos_embed_x, len_x, orig_w)[:, :target_w] # (1,W,C)
  695. # Broadcast, add and flatten to sequence layout (row major)
  696. pos_embed = pe_y.unsqueeze(2) + pe_x.unsqueeze(1) # (1, H, W, C)
  697. pos_embed_flat = pos_embed.flatten(1, 2) # (1, H*W, C)
  698. x.add_(pos_embed_flat)
    def forward(
            self,
            x: torch.Tensor,
            patch_coord: Optional[torch.Tensor] = None,
            patch_valid: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[int, int]]]:
        """Forward pass for patch embedding with position encoding.

        Projects the input to tokens, normalizes them, adds the configured
        position embedding in-place, prepends class/register tokens when
        present, and applies position dropout.

        Args:
            x: Input tensor. Supported formats:
                - [B, C, H, W] for conv mode
                - [B, N, P*P*C] for pre-patchified linear mode (normal)
                - [B, N, Ph, Pw, C] for pre-patchified linear mode (variable patch size)
            patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode.
            patch_valid: Optional validity mask for patches [B, N] for NaFlex mode.
                Not read anywhere in this method body.

        Returns:
            Tuple of (embedded_tensor, grid_size) where:
                - embedded_tensor: [B, num_prefix_tokens + N, embed_dim]
                - grid_size: (H, W) tuple for standard mode, None for NaFlex mode
        """
        # Stays None on the pre-patchified (NaFlex) path; selects the pos-embed variant below
        grid_size: Optional[Tuple[int, int]] = None
        B = x.shape[0]
        if self.is_linear:
            # Linear embedding path, works with NaFlex mode or standard 2D mode
            if patch_coord is None:
                # Standard 2D (B, C, H, W) mode
                _assert(x.ndim == 4, 'Expecting 2D image input with input ndim == 4')
                x, grid_size = batch_patchify(x, self.patch_size, pad=self.dynamic_img_pad)
            else:
                # Pre-patchified NaFlex mode
                # Variable patch size mode: [B, N, Ph, Pw, C], normal mode: [B, N, P*P*C]
                _assert(x.ndim == 5 or x.ndim == 3, 'Expecting patchified input with ndim == 3 or 5.')

            # Handle variable patch size projection
            if self.enable_patch_interpolator and x.ndim == 5:
                _assert(self.norm_input is None, 'input norm not supported with patch resizing')

                # Apply projection with interpolation
                x = self.patch_interpolator(
                    x,
                    self.proj.weight,
                    self.proj.bias,
                    patch_size=tuple(x.shape[2:4]),  # patch size from [B, N, Ph, Pw, C] shape
                    is_linear=True,
                )
            else:
                # Standard projection
                x = x.flatten(2)  # ensure [B, N, P*P*C], flatten Ph*Pw*C if separate
                if self.norm_input is not None:
                    x = self.norm_input(x)
                x = self.proj(x)
        else:
            # Projection path for 4D image input
            _assert(x.ndim == 4, 'Convolutional input must be 4D')
            if self.dynamic_img_pad:
                # Pad right/bottom so H and W become multiples of the patch size
                H, W = x.shape[-2:]
                pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
                pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
                x = F.pad(x, (0, pad_w, 0, pad_h))

            x = self.proj(x)

            # Grid size is read from the projected feature map before flattening
            grid_size = x.shape[-2:]
            if self.flatten:
                x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC

        # Apply normalization after flattening
        x = self.norm(x)

        # The _apply_* helpers return None and modify x in-place.
        # Any pos_embed_type other than 'learned' / 'factorized' adds nothing here.
        if self.pos_embed_type == 'learned':
            if grid_size is not None:
                # Standard 2D mode
                self._apply_learned_pos_embed(x, grid_size=grid_size)
            else:
                # NaFlex mode
                if self.pos_embed_use_grid_sample:
                    self._apply_learned_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord)
                else:
                    self._apply_learned_naflex_pos_embed(x, patch_coord=patch_coord)
        elif self.pos_embed_type == 'factorized':
            if grid_size is not None:
                # Standard 2D mode
                self._apply_factorized_pos_embed(x, grid_size=grid_size)
            else:
                # NaFlex mode
                if self.pos_embed_use_grid_sample:
                    self._apply_factorized_naflex_pos_embed_grid_sample(x, patch_coord=patch_coord)
                else:
                    self._apply_factorized_naflex_pos_embed(x, patch_coord=patch_coord)

        # Prepare and add class and register tokens
        # (added after the position embedding, so prefix tokens do not receive it)
        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(B, -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(B, -1, -1))
        # Add tokens to the beginning
        if to_cat:
            x = torch.cat(to_cat + [x], dim=1)

        # Apply dropout
        x = self.pos_drop(x)
        return x, grid_size
  792. @register_notrace_function
  793. def create_attention_mask(
  794. patch_valid: torch.Tensor,
  795. num_prefix_tokens: int = 0,
  796. symmetric: bool = True,
  797. q_len: Optional[int] = None,
  798. dtype: torch.dtype = torch.float32,
  799. ) -> Optional[torch.Tensor]:
  800. """Creates an attention mask from patch validity information.
  801. Supports two modes controlled by `symmetric`:
  802. 1. `symmetric=True` (default): Creates a symmetric mask of shape
  803. [B, 1, seq_len, seq_len]. An attention pair (i, j) is allowed only if
  804. both token i and token j are valid. Suitable for standard self-attention.
  805. 2. `symmetric=False`: Creates a potentially non-square mask of shape
  806. [B, 1, q_len, kv_len]. An attention pair (q, k) is allowed only if
  807. the key/value token k is valid. Query token validity is not checked
  808. in the mask itself. Useful for cross-attention or specific self-attention
  809. implementations `q_len` can be specified.
  810. Used for NaFlex mode to handle variable token counts and padding tokens.
  811. Args:
  812. patch_valid: Tensor of shape [B, N] with True for valid patches, False for padding.
  813. num_prefix_tokens: Number of prefix tokens (class token, register tokens)
  814. to prepend, which are always considered valid.
  815. symmetric: If True, create a symmetric mask.
  816. If False, create an expanded mask based only on key/value validity.
  817. q_len: Query sequence length override. Only used when `symmetric` is False.
  818. Defaults to the key/value sequence length (`kv_len`) if None.
  819. dtype: Dtype of the output attention mask (e.g., torch.float32).
  820. Returns:
  821. Attention mask tensor. Additive mask (-inf for masked, 0 for unmasked).
  822. Shape is [B, 1, seq_len, seq_len] if symmetric=True,
  823. or [B, 1, q_len, kv_len] if symmetric=False.
  824. """
  825. if patch_valid is None:
  826. return None
  827. patch_valid = patch_valid.bool() # Ensure boolean type
  828. B, N = patch_valid.shape
  829. kv_len = N # Initial key/value length is the number of patches
  830. # Prepend prefix tokens if any
  831. if num_prefix_tokens > 0:
  832. # Create prefix validity tensor on the same device/dtype base as patch_valid
  833. prefix_valid = patch_valid.new_ones((B, num_prefix_tokens), dtype=torch.bool)
  834. # Concatenate prefix and patch validity. Shape becomes [B, num_prefix_tokens + N]
  835. patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
  836. kv_len += num_prefix_tokens # Update total key/value sequence length
  837. if symmetric:
  838. # Symmetric mask is True where BOTH query and key are valid
  839. mask_bool = patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)
  840. mask_bool = mask_bool.unsqueeze(1) # Add head dimension: [B, 1, seq_len, seq_len]
  841. else:
  842. # Expanded mask
  843. q_len = q_len or kv_len
  844. mask_bool = patch_valid[:, None, None, :].expand(B, 1, q_len, kv_len)
  845. # Create the float mask and apply masking using additive mask convention
  846. mask_float = torch.zeros_like(mask_bool, dtype=dtype)
  847. # Fill with negative infinity where mask_bool is False (masked positions)
  848. mask_float.masked_fill_(~mask_bool, torch.finfo(dtype).min)
  849. return mask_float
  850. @register_notrace_function
  851. def global_pool_naflex(
  852. x: torch.Tensor,
  853. patch_valid: Optional[torch.Tensor] = None,
  854. pool_type: str = 'token',
  855. num_prefix_tokens: int = 1,
  856. reduce_include_prefix: bool = False,
  857. ) -> torch.Tensor:
  858. """Global pooling with NaFlex support for masked tokens.
  859. Applies global pooling while respecting patch validity masks to exclude
  860. padding tokens from pooling operations.
  861. Args:
  862. x: Input tensor with shape [B, N, C]
  863. patch_valid: Optional validity mask for patches [B, N-num_prefix_tokens]
  864. pool_type: Type of pooling ('token', 'avg', 'avgmax', 'max')
  865. num_prefix_tokens: Number of prefix tokens (class/register)
  866. reduce_include_prefix: Whether to include prefix tokens in pooling reduction
  867. Returns:
  868. Pooled tensor with shape [B, C]
  869. """
  870. if patch_valid is None or pool_type not in ('avg', 'avgmax', 'max'):
  871. # Fall back to standard pooling
  872. x = global_pool_nlc(
  873. x,
  874. pool_type=pool_type,
  875. num_prefix_tokens=num_prefix_tokens,
  876. reduce_include_prefix=reduce_include_prefix,
  877. )
  878. return x
  879. # For NaFlex mode, we need to apply masked pooling to exclude padding tokens
  880. if num_prefix_tokens > 0:
  881. if reduce_include_prefix:
  882. # Include prefix tokens in pooling - they are always considered valid
  883. # patch_valid only covers patch tokens, so create combined validity mask
  884. prefix_valid = patch_valid.new_ones(x.shape[0], num_prefix_tokens)
  885. patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
  886. else:
  887. # Exclude prefix tokens from pooling (default behavior)
  888. x = x[:, num_prefix_tokens:]
  889. patch_valid_float = patch_valid.to(x.dtype)
  890. if pool_type == 'avg':
  891. # Compute masked average pooling, sum valid tokens and divide by count of valid tokens
  892. masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
  893. valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
  894. pooled = masked_sums / valid_counts
  895. return pooled
  896. elif pool_type == 'avgmax':
  897. # For avgmax, compute masked average and masked max
  898. masked_sums = (x * patch_valid_float.unsqueeze(-1)).sum(dim=1)
  899. valid_counts = patch_valid_float.sum(dim=1, keepdim=True).clamp(min=1)
  900. masked_avg = masked_sums / valid_counts
  901. # For max pooling we set masked positions to large negative value
  902. masked_x = x.clone()
  903. masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
  904. masked_max = masked_x.amax(dim=1)
  905. # Combine average and max
  906. return 0.5 * (masked_avg + masked_max)
  907. elif pool_type == 'max':
  908. # For max pooling we set masked positions to large negative value
  909. masked_x = x.clone()
  910. masked_x[~patch_valid] = torch.finfo(masked_x.dtype).min
  911. return masked_x.amax(dim=1)
  912. else:
  913. assert False
  914. class NaFlexVit(nn.Module):
  915. """NaFlexVit: Vision Transformer with NaFlex support for flexible input handling.
  916. A flexible implementation of Vision Transformer that supports:
  917. - Standard image classification with various pooling strategies
  918. - NaFlex functionality for variable aspect ratios and resolutions
  919. - Linear patch embedding for pre-patchified inputs
  920. - Multiple position embedding strategies (learned, factorized, rope)
  921. - Comprehensive attention masking for efficient batch processing
  922. - Encapsulated embedding and position encoding in FlexEmbeds module
  923. - Compatible with standard ViT checkpoints through checkpoint filtering
  924. """
    def __init__(
            self,
            cfg: Optional[NaFlexVitCfg] = None,
            in_chans: int = 3,
            num_classes: int = 1000,
            img_size: Optional[Union[int, Tuple[int, int]]] = None,
            device=None,
            dtype=None,
            **kwargs,
    ) -> None:
        """Initialize NaFlexVit model.

        Args:
            cfg: Model configuration. If None, uses default NaFlexVitCfg.
            in_chans: Number of input image channels.
            num_classes: Number of classification classes.
            img_size: Input image size (for backwards compatibility with classic vit).
            device: Forwarded as a keyword argument to the submodules created here.
            dtype: Forwarded as a keyword argument to the submodules created here.
            **kwargs: Additional config parameters to override cfg values.
        """
        super().__init__()
        # Factory kwargs passed to the modules constructed below
        dd = {'device': device, 'dtype': dtype}

        # Initialize config
        cfg = cfg or NaFlexVitCfg()
        if kwargs:
            cfg = _overlay_kwargs(cfg, **kwargs)

        # Validate configuration
        assert cfg.global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
        # 'token' pooling needs a class token
        assert cfg.class_token or cfg.global_pool != 'token'
        assert cfg.pos_embed in ('', 'none', 'learned', 'factorized')

        # Resolve layer implementations
        norm_layer = get_norm_layer(cfg.norm_layer) or LayerNorm
        embed_norm_layer = get_norm_layer(cfg.embed_norm_layer)
        act_layer = get_act_layer(cfg.act_layer) or nn.GELU
        block_fn = get_block_fn(cfg)
        mlp_layer = cfg.mlp_layer or Mlp  # TODO: Support configurable mlp_layer via string lookup

        # Store instance variables
        self.num_classes = num_classes
        self.global_pool = cfg.global_pool
        self.num_features = self.head_hidden_size = self.embed_dim = cfg.embed_dim  # for consistency with other models
        # Prefix tokens = optional class token + register tokens
        self.num_prefix_tokens = 1 if cfg.class_token else 0
        self.num_prefix_tokens += cfg.reg_tokens
        self.num_reg_tokens = cfg.reg_tokens
        self.has_class_token = cfg.class_token
        self.pool_include_prefix = cfg.pool_include_prefix
        self.grad_checkpointing = False

        # Initialize embedding module (includes patch, position embedding, and class/reg tokens)
        # FlexEmbeds is always used - handles both linear and conv embedding
        self.embeds = NaFlexEmbeds(
            patch_size=cfg.patch_size,
            in_chans=in_chans,
            embed_dim=cfg.embed_dim,
            proj_type=cfg.embed_proj_type,
            proj_bias=not cfg.pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
            class_token=cfg.class_token,
            reg_tokens=cfg.reg_tokens,
            default_img_size=img_size,
            dynamic_img_pad=cfg.dynamic_img_pad,
            pos_embed=cfg.pos_embed,
            pos_embed_grid_size=cfg.pos_embed_grid_size,
            pos_embed_interp_mode=cfg.pos_embed_interp_mode,
            pos_embed_ar_preserving=cfg.pos_embed_ar_preserving,
            pos_embed_use_grid_sample=cfg.pos_embed_use_grid_sample,
            proj_norm_layer=embed_norm_layer,
            pos_drop_rate=cfg.pos_drop_rate,
            # getattr with a default: tolerates configs that lack this field
            enable_patch_interpolator=getattr(cfg, 'enable_patch_interpolator', False),
            **dd,
        )
        self.norm_pre = norm_layer(cfg.embed_dim, **dd) if cfg.pre_norm else nn.Identity()

        # ROPE position embeddings at model level
        self.rope: Optional[nn.Module] = None
        self.rope_is_mixed = False
        if cfg.rope_type and cfg.rope_type != 'none':
            # Imported here, only when a rope type is configured
            from timm.layers.pos_embed_sincos import RotaryEmbeddingCat, RotaryEmbeddingMixed
            if cfg.rope_type == 'mixed':
                self.rope = RotaryEmbeddingMixed(
                    cfg.embed_dim,
                    depth=cfg.depth,
                    num_heads=cfg.num_heads,
                    temperature=cfg.rope_temperature,
                    feat_shape=None,  # Dynamic shapes for NaFlex
                    grid_indexing=cfg.rope_grid_indexing,
                    **dd,
                )
                self.rope_is_mixed = True
            elif cfg.rope_type == 'axial':
                self.rope = RotaryEmbeddingCat(
                    cfg.embed_dim // cfg.num_heads,
                    temperature=cfg.rope_temperature,
                    in_pixels=False,
                    feat_shape=None,  # Dynamic shapes for NaFlex
                    ref_feat_shape=cfg.rope_ref_feat_shape,
                    grid_offset=cfg.rope_grid_offset,
                    grid_indexing=cfg.rope_grid_indexing,
                    **dd,
                )
                self.rope_is_mixed = False
            else:
                raise ValueError(f"Unknown rope_type: {cfg.rope_type}")

        # Patch dropout
        if cfg.patch_drop_rate > 0:
            self.patch_drop = PatchDropoutWithIndices(
                cfg.patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = None

        # Transformer blocks
        dpr = calculate_drop_path_rates(cfg.drop_path_rate, cfg.depth)  # stochastic depth decay rule
        # Create transformer blocks
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=cfg.embed_dim,
                num_heads=cfg.num_heads,
                mlp_ratio=cfg.mlp_ratio,
                qkv_bias=cfg.qkv_bias,
                qk_norm=cfg.qk_norm,
                proj_bias=cfg.proj_bias,
                init_values=cfg.init_values,
                proj_drop=cfg.proj_drop_rate,
                attn_drop=cfg.attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
                **dd,
            )
            for i in range(cfg.depth)
        ])

        # Feature info for downstream tasks
        # One entry per block, all sharing the embedding's reduction factor
        patch_reduction = self.embeds.feat_ratio(as_scalar=True)
        self.feature_info = [
            dict(module=f'blocks.{i}', num_chs=cfg.embed_dim, reduction=patch_reduction)
            for i in range(cfg.depth)
        ]

        # NOTE(review): this tests the raw cfg.fc_norm, not the default resolved below; with
        # cfg.fc_norm None and global_pool == 'avg' both self.norm and self.fc_norm are
        # instantiated - confirm this is intended.
        self.norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and not cfg.fc_norm else nn.Identity()

        # Classifier Head
        if cfg.global_pool == 'map':
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=cfg.attn_pool_num_heads or cfg.num_heads,
                mlp_ratio=cfg.attn_pool_mlp_ratio or cfg.mlp_ratio,
                norm_layer=norm_layer,
                act_layer=act_layer,
                **dd,
            )
        else:
            self.attn_pool = None

        # Handle fc_norm default value
        # (unset fc_norm defaults to True only for 'avg' pooling)
        fc_norm = cfg.fc_norm
        if fc_norm is None:
            fc_norm = cfg.global_pool == 'avg'
        self.fc_norm = norm_layer(cfg.embed_dim, **dd) if cfg.final_norm and fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(cfg.drop_rate)
        # num_classes <= 0 yields an Identity head
        self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()

        # weight_init == 'skip' leaves parameters as constructed
        if cfg.weight_init != 'skip':
            self.init_weights(cfg.weight_init)
        if cfg.fix_init:
            self.fix_init_weight()
  1082. def fix_init_weight(self) -> None:
  1083. """Apply initialization weight fix with layer-wise scaling."""
  1084. def rescale(param: torch.Tensor, _layer_id: int) -> None:
  1085. with torch.no_grad():
  1086. param.div_(math.sqrt(2.0 * _layer_id))
  1087. for layer_id, layer in enumerate(self.blocks):
  1088. if hasattr(layer, 'attn'):
  1089. rescale(layer.attn.proj.weight, layer_id + 1)
  1090. if hasattr(layer, 'mlp'):
  1091. rescale(layer.mlp.fc2.weight, layer_id + 1)
  1092. if hasattr(layer, 'attn_out_proj'):
  1093. rescale(layer.attn_out_proj.weight, layer_id + 1)
  1094. if hasattr(layer, 'mlp_out_proj'):
  1095. rescale(layer.mlp_out_proj.weight, layer_id + 1)
  1096. def init_weights(self, mode: str = '') -> None:
  1097. """Initialize model weights according to specified scheme.
  1098. Args:
  1099. mode: Initialization mode ('jax', 'jax_nlhb', 'moco', or '')
  1100. """
  1101. assert mode in ('jax', 'jax_nlhb', 'moco', '')
  1102. head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
  1103. named_apply(get_init_weights_vit(mode, head_bias), self)
  1104. @torch.jit.ignore()
  1105. def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
  1106. # Custom loading for the new model structure
  1107. from .vision_transformer import _load_weights as _orig_load_weights
  1108. def _load_weights_adapter(model, checkpoint_path, prefix=''):
  1109. """Adapter function to handle the different model structure"""
  1110. state_dict = torch.load(checkpoint_path, map_location='cpu')
  1111. if isinstance(state_dict, dict) and 'state_dict' in state_dict:
  1112. state_dict = state_dict['state_dict']
  1113. # Map original keys to new structure
  1114. for k in list(state_dict.keys()):
  1115. if k.startswith('cls_token'):
  1116. state_dict['embeds.' + k] = state_dict.pop(k)
  1117. elif k.startswith('reg_token'):
  1118. state_dict['embeds.' + k] = state_dict.pop(k)
  1119. elif k.startswith('pos_embed'):
  1120. state_dict['embeds.' + k] = state_dict.pop(k)
  1121. elif k.startswith('patch_embed'):
  1122. state_dict['embeds.' + k[12:]] = state_dict.pop(k)
  1123. return _orig_load_weights(model, state_dict, prefix)
  1124. _load_weights_adapter(self, checkpoint_path, prefix)
  1125. @torch.jit.ignore
  1126. def no_weight_decay(self) -> Set:
  1127. """Get set of parameter names that should not have weight decay applied.
  1128. Returns:
  1129. Set of parameter names to skip during weight decay
  1130. """
  1131. skip_list = {'embeds.pos_embed', 'embeds.cls_token', 'embeds.reg_token'}
  1132. if self.rope and hasattr(self.rope, 'no_weight_decay'):
  1133. skip_list.update(self.rope.no_weight_decay())
  1134. return skip_list
  1135. @torch.jit.ignore
  1136. def group_matcher(self, coarse: bool = False) -> Dict:
  1137. """Get parameter group matcher for optimizer parameter grouping.
  1138. Args:
  1139. coarse: Whether to use coarse-grained grouping
  1140. Returns:
  1141. Dictionary mapping group names to regex patterns
  1142. """
  1143. return dict(
  1144. stem=r'^embeds', # stem and embed
  1145. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  1146. )
  1147. @torch.jit.ignore
  1148. def set_grad_checkpointing(self, enable: bool = True) -> None:
  1149. """Enable or disable gradient checkpointing for memory efficiency.
  1150. Args:
  1151. enable: Whether to enable gradient checkpointing
  1152. """
  1153. self.grad_checkpointing = enable
  1154. if hasattr(self.embeds, 'patch_embed') and hasattr(self.embeds.patch_embed, 'set_grad_checkpointing'):
  1155. self.embeds.patch_embed.set_grad_checkpointing(enable)
  1156. @torch.jit.ignore
  1157. def get_classifier(self) -> nn.Module:
  1158. """Get the classification head module.
  1159. Returns:
  1160. Classification head module
  1161. """
  1162. return self.head
  @disable_compiler
  def _generate_rope_naflex(
          self,
          x: torch.Tensor,
          patch_coord: torch.Tensor,
  ) -> Union[torch.Tensor, List[torch.Tensor], Any]:
      """Build ROPE position embeddings for a NaFlex batch whose samples may differ in grid size.

      Samples are grouped by grid size so each distinct size is only generated once.

      Args:
          x: Input tensor [B, N, C]
          patch_coord: Patch coordinates [B, N, 2] with (y, x) values

      Returns:
          ROPE embeddings:
          - Axial mode: Tensor of shape [B, 1, N, dim*2], zero beyond each sample's h * w positions
          - Mixed mode: a ``NaFlexRopeIterator`` producing the per-depth embeddings
      """
      # Group batch indices by grid size; dict insertion order records first-seen order of sizes.
      size_to_indices = {}
      for sample_idx, grid in enumerate(calculate_naflex_grid_sizes(patch_coord)):
          size_to_indices.setdefault(grid, []).append(sample_idx)
      unique_sizes = list(size_to_indices)

      batch_size, num_tokens, _ = x.shape
      seq_len = num_tokens - self.num_prefix_tokens

      if self.rope_is_mixed:
          # Mixed mode defers generation to an iterator consumed alongside the blocks.
          return NaFlexRopeIterator(
              self.rope,
              size_to_indices,
              unique_sizes,
              batch_size,
              seq_len,
              x.dtype,
              x.device
          )

      # Axial mode: fill a zero-initialised [batch_size, seq_len, dim*2] buffer per sample.
      rope_embeds = torch.zeros(batch_size, seq_len, self.rope.dim * 2, dtype=x.dtype, device=x.device)

      if hasattr(self.rope, 'get_batch_embeds'):
          # Batched generation: one call for all unique sizes, then scatter to the owning samples.
          unique_embeds = self.rope.get_batch_embeds(unique_sizes)
          for (grid, sample_indices), embed in zip(size_to_indices.items(), unique_embeds):
              grid_h, grid_w = grid
              valid_len = grid_h * grid_w
              for sample_idx in sample_indices:
                  rope_embeds[sample_idx, :valid_len] = embed[:valid_len]
      else:
          # Fallback: generate each unique size on its own and assign to all matching samples at once.
          for grid, sample_indices in size_to_indices.items():
              embed = self.rope.get_embed(shape=grid)
              grid_h, grid_w = grid
              valid_len = grid_h * grid_w
              rope_embeds[sample_indices, :valid_len] = embed[:valid_len]

      # Insert the head dimension: [B, N, dim*2] -> [B, 1, N, dim*2]
      return rope_embeds.unsqueeze(1)
  def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
      """Replace the classification head and optionally change the pooling type.

      Args:
          num_classes: Output size of the new head; 0 (or less) installs ``nn.Identity``.
          global_pool: New pooling type, or None to leave pooling unchanged. Switching to
              'map' is only accepted when an attention pool already exists; switching away
              from 'map' drops the existing attention pool.
      """
      self.num_classes = num_classes

      if global_pool is not None:
          assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
          wants_map = global_pool == 'map'
          has_attn_pool = self.attn_pool is not None
          if wants_map and not has_attn_pool:
              assert False, "Cannot currently add attention pooling in reset_classifier()."
          elif not wants_map and has_attn_pool:
              # attention pooling is no longer needed
              self.attn_pool = None
          self.global_pool = global_pool

      if num_classes > 0:
          self.head = nn.Linear(self.embed_dim, num_classes)
      else:
          self.head = nn.Identity()
  def _forward_embeds(
          self,
          x,
          patch_coord,
          patch_valid,
          attn_mask,
  ) -> Dict[str, torch.Tensor]:
      """Run patch / abs pos embedding, ROPE generation and patch dropout ahead of the blocks.

      Args:
          x: Input handed to ``self.embeds``.
          patch_coord: Patch coordinates; a non-None value selects NaFlex mode.
          patch_valid: Optional patch validity indicators, filtered when patches are dropped.
          attn_mask: Optional attention mask; when None, one is created from ``patch_valid``.

      Returns:
          Dict with keys 'patches', 'patch_valid', 'rope_embeds', 'attn_mask' and 'keep_indices'.
      """
      has_coord = patch_coord is not None

      # Patch + abs pos embedding; grid_size is the global grid computed for 'standard' NCHW batches.
      tokens, grid_size = self.embeds(
          x,
          patch_coord=patch_coord,
          patch_valid=patch_valid,
      )

      # ROPE embeddings are produced at model level.
      rope = None
      if self.rope is not None:
          if has_coord:
              # NaFlex mode: grid size varies per sample
              rope = self._generate_rope_naflex(tokens, patch_coord)
          elif grid_size is not None:
              # Standard mode: one grid size for the whole batch
              rope = self.rope.get_embed(shape=grid_size)
          else:
              assert False, 'Expected one of patch_coord or grid_size to be valid'

      # Patch dropout, keeping patch_valid and (axial) ROPE aligned with the surviving tokens.
      keep_indices: Optional[torch.Tensor] = None
      if self.training and self.patch_drop is not None:
          tokens, keep_indices = self.patch_drop(tokens)
          # keep_indices excludes prefix tokens, so it indexes patch_valid / rope directly.
          if patch_valid is not None:
              patch_valid = patch_valid.gather(1, keep_indices)
          if rope is not None and not self.rope_is_mixed:
              # Mixed-mode ROPE is handled per layer by the callers; only axial is updated here.
              rope = apply_keep_indices_nlc(tokens, rope, keep_indices, pos_embed_has_batch=has_coord)
              if not has_coord:
                  # Standard mode still lacks the head dim: B, N, dim -> B, 1, N, dim
                  rope = rope.unsqueeze(1)

      # The default mask is derived from patch_valid *after* dropout has been applied.
      if attn_mask is None:
          attn_mask = create_attention_mask(
              patch_valid,
              num_prefix_tokens=self.num_prefix_tokens,
              dtype=tokens.dtype
          )

      return {
          'patches': self.norm_pre(tokens),
          'patch_valid': patch_valid,
          'rope_embeds': rope,
          'attn_mask': attn_mask,
          'keep_indices': keep_indices,
      }
  def forward_intermediates(
          self,
          x: Union[torch.Tensor, Dict[str, torch.Tensor]],
          indices: Optional[Union[int, List[int]]] = None,
          return_prefix_tokens: bool = False,
          norm: bool = False,
          stop_early: bool = False,
          output_fmt: str = 'NCHW',
          intermediates_only: bool = False,
          output_dict: bool = False,
          patch_coord: Optional[torch.Tensor] = None,
          patch_valid: Optional[torch.Tensor] = None,
          attn_mask: Optional[torch.Tensor] = None,
  ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]], Dict[str, Any]]:
      """ Forward features that returns intermediates.

      Args:
          x: Input image tensor
          indices: Take last n blocks if int, all if None, select matching indices if sequence
          return_prefix_tokens: Return both prefix and spatial intermediate tokens
          norm: Apply norm layer to all intermediates
          stop_early: Stop iterating over blocks when last desired intermediate hit
          output_fmt: Shape of intermediate feature outputs
          intermediates_only: Only return intermediate features
          output_dict: Return outputs as a dictionary with 'image_features' and 'image_intermediates' keys
          patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode
          patch_valid: Optional patch type indicators (1=patch, 0=padding) for NaFlex
          attn_mask: Optional attention mask for masked attention

      Returns:
          A tuple with (final_features, intermediates), a list of intermediate features, or a dictionary containing
          'image_features' and 'image_intermediates' (and optionally 'image_intermediates_prefix')
      """
      # FIXME unfinished / untested

      assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
      reshape = output_fmt == 'NCHW'
      intermediates = []
      take_indices, max_index = feature_take_indices(len(self.blocks), indices)

      if isinstance(x, Dict):
          # Handle dictionary input from NaFlex collator
          patch_coord = x['patch_coord']
          patch_valid = x['patch_valid']
          patches = x['patches']
          assert False, 'WIP, patch mode needs more work'
      else:
          patches = x
          height, width = x.shape[-2:]
          H, W = self.embeds.dynamic_feat_size((height, width))

      # Same NaFlex criterion as _forward_embeds / forward_features. The dict returned by
      # _forward_embeds carries no 'naflex_mode' entry, so it must be derived here.
      naflex_mode = patch_coord is not None

      # Forward pass through patch and abs position embedding
      embeds = self._forward_embeds(
          patches,
          patch_coord=patch_coord,
          patch_valid=patch_valid,
          attn_mask=attn_mask,
      )
      x = embeds['patches']
      rope_embeds = embeds.get('rope_embeds', None)
      keep_indices = embeds.get('keep_indices', None)
      attn_mask = embeds.get('attn_mask', None)

      # Forward pass through blocks
      if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
          blocks = self.blocks
      else:
          blocks = self.blocks[:max_index + 1]

      do_checkpointing = self.grad_checkpointing and not torch.jit.is_scripting()
      if self.rope_is_mixed and rope_embeds is not None:
          # Mixed mode with per-layer embeddings (list or iterator).
          # Iterate the (possibly truncated) `blocks` so stop_early is honoured here as well.
          for i, (blk, rope_embed) in enumerate(zip(blocks, rope_embeds)):
              if self.training and self.patch_drop is not None and keep_indices is not None:
                  # Apply patch dropout to rope_embed if needed (batch dim already present in naflex mode)
                  rope_embed = apply_keep_indices_nlc(
                      x,
                      rope_embed,
                      keep_indices,
                      pos_embed_has_batch=naflex_mode,
                  )
              if do_checkpointing:
                  x = checkpoint(blk, x, rope=rope_embed, attn_mask=attn_mask)
              else:
                  x = blk(x, rope=rope_embed, attn_mask=attn_mask)
              if i in take_indices:
                  # normalize intermediates with final norm layer if enabled
                  intermediates.append(self.norm(x) if norm else x)
      else:
          for i, blk in enumerate(blocks):
              if rope_embeds is not None:
                  # Axial ROPE mode with shared embeddings
                  if do_checkpointing:
                      x = checkpoint(blk, x, rope=rope_embeds, attn_mask=attn_mask)
                  else:
                      x = blk(x, rope=rope_embeds, attn_mask=attn_mask)
              else:
                  if do_checkpointing:
                      x = checkpoint(blk, x, attn_mask=attn_mask)
                  else:
                      x = blk(x, attn_mask=attn_mask)
              if i in take_indices:
                  # normalize intermediates with final norm layer if enabled
                  intermediates.append(self.norm(x) if norm else x)

      # Process intermediates
      if self.num_prefix_tokens:
          # split prefix (e.g. class, distill) and spatial feature tokens
          prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
          intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
      else:
          prefix_tokens = None

      if reshape:
          # reshape to BCHW output format
          intermediates = [
              y.reshape(y.shape[0], H, W, -1).permute(0, 3, 1, 2).contiguous()
              for y in intermediates
          ]

      # FIXME always use dict for NaFlex mode to return masks and more?

      # For dictionary output
      if output_dict:
          result_dict = {}
          # Intermediates are always included
          result_dict['image_intermediates'] = intermediates
          if prefix_tokens is not None and return_prefix_tokens:
              result_dict['image_intermediates_prefix'] = prefix_tokens

          # Only include features if not intermediates_only
          if not intermediates_only:
              x_final = self.norm(x)
              result_dict['image_features'] = x_final

          return result_dict

      # For non-dictionary output, maintain the original behavior
      if not torch.jit.is_scripting() and return_prefix_tokens and prefix_tokens is not None:
          # return_prefix not support in torchscript due to poor type handling
          intermediates = list(zip(intermediates, prefix_tokens))

      if intermediates_only:
          return intermediates

      x = self.norm(x)

      return x, intermediates
  def forward_features(
          self,
          patches: torch.Tensor,
          patch_coord: Optional[torch.Tensor] = None,
          patch_valid: Optional[torch.Tensor] = None,
          attn_mask: Optional[torch.Tensor] = None,
  ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
      """Embed the input, run every transformer block and apply the final norm.

      Args:
          patches: Input passed on to ``_forward_embeds``.
          patch_coord: Optional patch coordinates; a non-None value selects NaFlex mode.
          patch_valid: Optional patch validity indicators.
          attn_mask: Optional attention mask forwarded to ``_forward_embeds``.

      Returns:
          In NaFlex mode a dict with 'patches' (normalized tokens) and 'patch_valid'
          (as returned by ``_forward_embeds``); otherwise the normalized token tensor.
      """
      naflex_mode = patch_coord is not None

      # Patch & abs position embedding with patch coordinate / validity support
      embed_out = self._forward_embeds(
          patches,
          patch_coord=patch_coord,
          patch_valid=patch_valid,
          attn_mask=attn_mask,
      )
      hidden = embed_out['patches']
      rope_embeds = embed_out.get('rope_embeds', None)
      keep_indices = embed_out.get('keep_indices', None)
      attn_mask = embed_out.get('attn_mask', None)

      use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()

      if self.rope_is_mixed and rope_embeds is not None:
          # Mixed mode: one embedding per layer (list or iterator), paired with the blocks.
          for blk, layer_rope in zip(self.blocks, rope_embeds):
              if self.training and self.patch_drop is not None and keep_indices is not None:
                  # Match the layer's rope to the kept tokens (batch dim already present in naflex mode)
                  layer_rope = apply_keep_indices_nlc(
                      hidden,
                      layer_rope,
                      keep_indices,
                      pos_embed_has_batch=naflex_mode,
                  )
              if use_checkpoint:
                  hidden = checkpoint(blk, hidden, rope=layer_rope, attn_mask=attn_mask)
              else:
                  hidden = blk(hidden, rope=layer_rope, attn_mask=attn_mask)
      elif rope_embeds is not None:
          # Axial mode: every block shares the same embeddings.
          for blk in self.blocks:
              if use_checkpoint:
                  hidden = checkpoint(blk, hidden, rope=rope_embeds, attn_mask=attn_mask)
              else:
                  hidden = blk(hidden, rope=rope_embeds, attn_mask=attn_mask)
      else:
          # No ROPE: blocks only receive the attention mask.
          for blk in self.blocks:
              if use_checkpoint:
                  hidden = checkpoint(blk, hidden, attn_mask=attn_mask)
              else:
                  hidden = blk(hidden, attn_mask=attn_mask)

      hidden = self.norm(hidden)

      if not naflex_mode:
          return hidden
      return {
          'patches': hidden,
          'patch_valid': embed_out.get('patch_valid', None),
      }
  def _pool(
          self,
          x: torch.Tensor,
          pool_type: Optional[str] = None,
          patch_valid: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """Pool the token sequence, through attention pooling when configured, else global pooling.

      Args:
          x: Token tensor to pool.
          pool_type: Pooling type for the non-attention path; defaults to ``self.global_pool``.
          patch_valid: Optional patch validity indicators used for masking.

      Returns:
          The pooled tensor.
      """
      if self.attn_pool is None:
          # Masked global pooling; an explicit pool_type overrides the model default.
          if pool_type is None:
              pool_type = self.global_pool
          return global_pool_naflex(
              x,
              patch_valid,
              pool_type=pool_type,
              num_prefix_tokens=self.num_prefix_tokens,
              reduce_include_prefix=self.pool_include_prefix,
          )

      # Attention pooling: the mask only accounts for prefix tokens when they take part in pooling.
      mask_prefix_tokens = self.num_prefix_tokens if self.pool_include_prefix else 0
      attn_mask = create_attention_mask(
          patch_valid,
          num_prefix_tokens=mask_prefix_tokens,
          symmetric=False,
          q_len=1,
          dtype=x.dtype,
      )
      if not self.pool_include_prefix:
          # Strip the prefix tokens so only patch tokens are pooled.
          x = x[:, self.num_prefix_tokens:]
      return self.attn_pool(x, attn_mask=attn_mask)
  def forward_head(
          self,
          patches: torch.Tensor,
          pre_logits: bool = False,
          patch_valid: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """Pool the tokens, then apply ``fc_norm``, head dropout and (optionally) the classifier.

      Args:
          patches: Token tensor produced by the feature extractor.
          pre_logits: When True, return the features before ``self.head`` is applied.
          patch_valid: Optional patch validity indicators passed on to pooling.

      Returns:
          Head output, or the pre-head features when ``pre_logits`` is True.
      """
      pooled = self._pool(patches, patch_valid=patch_valid)
      features = self.head_drop(self.fc_norm(pooled))
      if pre_logits:
          return features
      return self.head(features)
  def forward(
          self,
          x: Union[torch.Tensor, Dict[str, torch.Tensor]],
          patch_coord: Optional[torch.Tensor] = None,
          patch_valid: Optional[torch.Tensor] = None,
          attn_mask: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """Forward pass with optional NaFlex support.

      NaFlex mode is selected when ``x`` is a dict or ``patch_coord`` is given.

      Args:
          x: Input tensor. Supported formats:
              - [B, C, H, W] standard image input
              - [B, N, P*P*C] pre-patchified tensor (flattened patches)
              - [B, N, Ph, Pw, C] pre-patchified tensor (variable patch size)
              - Dict from NaFlex collator
          patch_coord: Optional patch coordinates [B, N, 2] for NaFlex mode.
          patch_valid: Optional patch validity indicators for NaFlex.
          attn_mask: Optional attn mask to override defaults generated from patch_valid

      Returns:
          Model output tensor.
      """
      dict_input = isinstance(x, Dict)

      if not dict_input and patch_coord is None:
          # Standard (non-NaFlex) path: plain tensor in, tensor through features and head.
          return self.forward_head(self.forward_features(x))

      if dict_input:
          # Values present in the collator dict take priority over the keyword arguments.
          patches = x['patches']
          patch_valid = x.get('patch_valid', patch_valid)
          patch_coord = x.get('patch_coord', patch_coord)
          attn_mask = x.get('attn_mask', attn_mask)
      else:
          patches = x
      _assert(patch_coord is not None, "patch_coord is required in naflex mode")
      _assert(patch_valid is not None, "patch_valid is required in naflex mode")

      features = self.forward_features(
          patches=patches,
          patch_valid=patch_valid,
          patch_coord=patch_coord,
          attn_mask=attn_mask,
      )
      # features holds 'patches' and 'patch_valid', which the head needs for masked pooling.
      return self.forward_head(**features)
  def _debug_dump_patches(x, patch_size: int = 16, in_chans: int = 3):
      """Debug helper: rebuild each sample's image from its patches and save it as ``patch_{i}.jpg``.

      Args:
          x: Dict with 'patches', 'patch_coord' and 'patch_valid' entries, indexed per sample
              along the first dimension.
          patch_size: Side length of a (square) patch in pixels. Defaults to 16.
          in_chans: Number of image channels per pixel. Defaults to 3.
      """
      # Local import: torchvision is only needed when this debug helper is actually called.
      from torchvision.utils import save_image

      patch_coord = x['patch_coord']
      patch_valid = x['patch_valid']
      patches = x['patches']
      for i, (sample_patches, coords, valid) in enumerate(zip(patches, patch_coord, patch_valid)):
          # Keep only the patches flagged valid for this sample.
          patch = sample_patches[valid]
          # Grid extent is the largest (y, x) coordinate plus one.
          h = (coords[:, 0].max() + 1).item()
          w = (coords[:, 1].max() + 1).item()
          # (h, w, P, P, C) -> (C, h, P, w, P) -> (C, h*P, w*P)
          patch = patch.reshape(h, w, patch_size, patch_size, in_chans).permute(4, 0, 2, 1, 3)
          patch = patch.reshape(in_chans, h * patch_size, w * patch_size)
          save_image(patch, f'patch_{i}.jpg', normalize=True)
  def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
      """Pick the weight-init function from vision_transformer.py that matches ``mode``.

      Args:
          mode: Init scheme name; matched by substring, 'jax' first, then 'moco'.
          head_bias: Bias value bound into the jax initializer (unused by the others).

      Returns:
          The jax initializer with ``head_bias`` bound, the moco initializer, or the timm
          initializer when neither substring is present.
      """
      from .vision_transformer import init_weights_vit_jax, init_weights_vit_moco, init_weights_vit_timm

      if 'jax' in mode:
          return partial(init_weights_vit_jax, head_bias=head_bias)
      if 'moco' in mode:
          return init_weights_vit_moco
      return init_weights_vit_timm
  def checkpoint_filter_fn(state_dict: Dict[str, Any], model: 'NaFlexVit') -> Dict[str, Any]:
      """Handle state dict conversion from original ViT to the new version with combined embedding.

      Key remapping performed:
          - 'pos_embed'       -> 'embeds.pos_embed' (reshaped from (1, N, C) to (1, H, W, C) when possible)
          - 'cls_token'       -> 'embeds.cls_token'
          - 'reg_token'       -> 'embeds.reg_token'
          - 'patch_embed.X'   -> 'embeds.X' ('proj.weight' is additionally permuted and flattened)
          - anything else is copied through unchanged.

      Args:
          state_dict: Checkpoint state dict in the original ViT layout. When the position
              embedding covers the prefix tokens, the matching slices are added in place
              to the 'cls_token' / 'reg_token' entries of this dict.
          model: Target model; ``model.embeds`` is inspected to decide how to reshape.

      Returns:
          A new dict with keys remapped to the combined-embedding layout.
      """
      out_dict = {}
      for k, v in state_dict.items():
          if k == 'pos_embed':
              # Handle position embedding format conversion - from (1, N, C) to (1, H, W, C)
              if hasattr(model.embeds, 'pos_embed') and v.ndim == 3:
                  num_cls_token = state_dict['cls_token'].shape[1] if 'cls_token' in state_dict else 0
                  num_reg_token = state_dict['reg_token'].shape[1] if 'reg_token' in state_dict else 0
                  num_prefix_tokens = num_cls_token + num_reg_token

                  num_patches = v.shape[1]
                  num_patches_no_prefix = num_patches - num_prefix_tokens
                  grid_size_no_prefix = math.sqrt(num_patches_no_prefix)
                  grid_size = math.sqrt(num_patches)
                  if (
                          grid_size_no_prefix != grid_size
                          and grid_size_no_prefix.is_integer()
                          and not grid_size.is_integer()
                  ):
                      # Only the prefix-free count is a perfect square, so the original pos_embed
                      # included the prefix tokens: fold those slices into the tokens themselves.
                      num_patches = num_patches_no_prefix
                      cls_token_emb = v[:, 0:num_cls_token]
                      if cls_token_emb.numel():
                          state_dict['cls_token'] += cls_token_emb
                      # Register tokens follow the class token, so the slice ends at the prefix count.
                      reg_token_emb = v[:, num_cls_token:num_prefix_tokens]
                      if reg_token_emb.numel():
                          state_dict['reg_token'] += reg_token_emb
                      v = v[:, num_prefix_tokens:]
                      grid_size = grid_size_no_prefix
                  grid_size = int(grid_size)

                  if grid_size * grid_size == num_patches:
                      # Square grid: reshape from (1, N, C) to (1, H, W, C)
                      v = v.reshape(1, grid_size, grid_size, v.shape[2])
                  else:
                      # Not a square grid: fall back to the model's own grid dimensions, if it exposes any.
                      patch_embed = getattr(model.embeds, 'patch_embed', None)
                      model_grid = getattr(patch_embed, 'grid_size', None)
                      if model_grid is not None:
                          h, w = model_grid
                          if h * w == num_patches:
                              v = v.reshape(1, h, w, v.shape[2])
                          else:
                              _logger.warning(
                                  f"Position embedding size mismatch: checkpoint={num_patches}, model={(h * w)}. "
                                  f"Using default initialization and will resize in forward pass."
                              )
                              # v is kept as is

              out_dict['embeds.pos_embed'] = v
          elif k == 'cls_token':
              out_dict['embeds.cls_token'] = v
          elif k == 'reg_token':
              out_dict['embeds.reg_token'] = v
          elif k.startswith('patch_embed.'):
              # Convert patch_embed.X to embeds.X
              suffix = k[len('patch_embed.'):]
              if suffix == 'proj.weight':
                  # Conv-style weight (O, C, H, W) -> linear-style weight (O, H*W*C)
                  v = v.permute(0, 2, 3, 1).flatten(1)
              out_dict['embeds.' + suffix] = v
          else:
              out_dict[k] = v

      return out_dict
  def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
      """Build a default pretrained-config dict for this file's models.

      Args:
          url: Value stored under the 'url' key.
          **kwargs: Extra entries; they override the defaults on key collision.

      Returns:
          The config dict with defaults and overrides merged.
      """
      cfg = {
          'url': url,
          'num_classes': 1000,
          'input_size': (3, 384, 384),
          'pool_size': None,
          'crop_pct': 1.0,
          'interpolation': 'bicubic',
          'mean': IMAGENET_INCEPTION_MEAN,
          'std': IMAGENET_INCEPTION_STD,
          'first_conv': 'embeds.proj',
          'classifier': 'head',
          'license': 'apache-2.0',
      }
      cfg.update(kwargs)
      return cfg
  # Default pretrained configs, keyed by '<model name>.<tag>'.
  default_cfgs = generate_default_cfgs({
      # ImageNet-1k trained base models with global average pooling
      'naflexvit_base_patch16_gap.e300_s576_in1k': _cfg(hf_hub_id='timm/'),
      'naflexvit_base_patch16_par_gap.e300_s576_in1k': _cfg(hf_hub_id='timm/'),
      'naflexvit_base_patch16_parfac_gap.e300_s576_in1k': _cfg(hf_hub_id='timm/'),

      # Entries tagged 'untrained'
      'naflexvit_base_patch16_map.untrained': _cfg(),
      'naflexvit_so150m2_patch16_reg1_gap.untrained': _cfg(),
      'naflexvit_so150m2_patch16_reg1_map.untrained': _cfg(),

      # SigLIP-2 NaFlex vit encoder weights
      'naflexvit_base_patch16_siglip.v2_webli': _cfg(
          hf_hub_id='timm/',
          num_classes=0,
      ),
      'naflexvit_so400m_patch16_siglip.v2_webli': _cfg(
          hf_hub_id='timm/',
          num_classes=0,
      ),
  })
  def _create_naflexvit(variant: str, pretrained: bool = False, **kwargs) -> NaFlexVit:
      """Build a NaFlexVit through ``build_model_with_cfg``.

      Args:
          variant: Model variant name.
          pretrained: Passed through to ``build_model_with_cfg``.
          **kwargs: May contain 'out_indices' (default 3) and 'cfg' (default ``NaFlexVitCfg()``).
              Keys that name a ``NaFlexVitCfg`` field are overlaid onto the config; everything
              else is forwarded to ``build_model_with_cfg``.

      Returns:
          The constructed model.
      """
      out_indices = kwargs.pop('out_indices', 3)
      cfg = kwargs.pop('cfg', NaFlexVitCfg())

      # Move config-field overrides out of kwargs so they are not forwarded a second time.
      known_fields = {field.name for field in fields(NaFlexVitCfg)}
      overrides = {}
      for key in tuple(kwargs):
          if key in known_fields:
              overrides[key] = kwargs.pop(key)
      if overrides:
          cfg = _overlay_kwargs(cfg, **overrides)

      feature_cfg = {'out_indices': out_indices, 'feature_cls': 'getter'}
      return build_model_with_cfg(
          NaFlexVit, variant, pretrained,
          pretrained_filter_fn=checkpoint_filter_fn,
          cfg=cfg,
          feature_cfg=feature_cfg,
          **kwargs,
      )
  def _create_naflexvit_from_classic(
          variant: str,
          pretrained: bool = False,
          **kwargs,
  ) -> NaFlexVit:
      """Create a NaFlexVit model from a classic VisionTransformer configuration.

      Translates classic VisionTransformer keyword arguments and defaults into the
      arguments expected by ``_create_naflexvit``.

      Args:
          variant: Model variant name
          pretrained: Whether to load pretrained weights
          **kwargs: Classic VisionTransformer parameters

      Returns:
          NaFlexVit model instance
      """
      # VisionTransformer-only options that have no counterpart here
      for unused_key in ('no_embed_class', 'dynamic_img_size'):
          kwargs.pop(unused_key, None)

      # Original ViTs default to cls token pooling, and enable fc_norm when it is
      # left unset and average pooling is used.
      global_pool = kwargs.pop('global_pool', 'token')
      fc_norm = kwargs.pop('fc_norm', None)
      if fc_norm is None and global_pool == 'avg':
          fc_norm = True

      scale_mlp_norm = kwargs.pop('scale_mlp_norm', False)
      scale_attn_inner_norm = kwargs.pop('scale_attn_norm', False)  # renamed on the NaFlex side

      flex_kwargs = {
          'pos_embed_grid_size': None,  # rely on img_size (// patch_size) that will be passed through
          'class_token': kwargs.get('class_token', True),
          'global_pool': global_pool,
          'fc_norm': fc_norm,
          'scale_mlp_norm': scale_mlp_norm,
          'scale_attn_inner_norm': scale_attn_inner_norm,
      }
      # Remaining user kwargs win over the defaults above.
      flex_kwargs.update(kwargs)
      return _create_naflexvit(variant, pretrained, **flex_kwargs)
  def _create_naflexvit_from_eva(
          variant: str,
          pretrained: bool = False,
          **kwargs,
  ) -> NaFlexVit:
      """Create a NaFlexVit model from an EVA configuration.

      Translates EVA keyword arguments (rope flags, norm / pooling options, block options)
      into the arguments expected by ``_create_naflexvit``.

      Args:
          variant: Model variant name
          pretrained: Whether to load pretrained weights
          **kwargs: EVA model parameters

      Returns:
          NaFlexVit model instance
      """
      # EVA specific, not used in NaFlexVit (always no-embed)
      kwargs.pop('no_embed_class', None)

      # EVA's two rope flags collapse into a single rope type.
      use_rot_pos_emb = kwargs.pop('use_rot_pos_emb', False)
      rope_mixed_mode = kwargs.pop('rope_mixed_mode', False)
      rope_temperature = kwargs.pop('rope_temperature', 10000.)
      rope_grid_offset = kwargs.pop('rope_grid_offset', 0.)
      rope_grid_indexing = kwargs.pop('rope_grid_indexing', 'ij')
      if not use_rot_pos_emb:
          rope_type = 'none'
      elif rope_mixed_mode:
          rope_type = 'mixed'
      else:
          rope_type = 'axial'

      # Norm / pool resolution mirrors EVA: fc_norm defaults to on when avg pooling is used.
      global_pool = kwargs.pop('global_pool', 'avg')
      pre_norm = kwargs.pop('use_pre_transformer_norm', False)
      final_norm = kwargs.pop('use_post_transformer_norm', True)
      fc_norm = kwargs.pop('use_fc_norm', None)
      if fc_norm is None:
          fc_norm = global_pool == 'avg'

      class_token = kwargs.get('class_token', True)
      reg_tokens = kwargs.pop('num_reg_tokens', kwargs.get('reg_tokens', 0))
      pos_embed = 'learned' if kwargs.pop('use_abs_pos_emb', True) else 'none'
      rope_ref_feat_shape = kwargs.get('ref_feat_shape', None)
      attn_type = kwargs.pop('attn_type', 'eva')
      swiglu_mlp = kwargs.pop('swiglu_mlp', False)
      qkv_fused = kwargs.pop('qkv_fused', True)
      scale_mlp_norm = kwargs.pop('scale_mlp', False)
      scale_attn_inner_norm = kwargs.pop('scale_attn_inner', False)

      naflex_kwargs = {
          'pos_embed_grid_size': None,  # rely on img_size (// patch_size)
          'class_token': class_token,
          'reg_tokens': reg_tokens,
          'global_pool': global_pool,
          'pre_norm': pre_norm,
          'final_norm': final_norm,
          'fc_norm': fc_norm,
          'pos_embed': pos_embed,
          'rope_type': rope_type,
          'rope_temperature': rope_temperature,
          'rope_grid_offset': rope_grid_offset,
          'rope_grid_indexing': rope_grid_indexing,
          'rope_ref_feat_shape': rope_ref_feat_shape,
          'attn_type': attn_type,
          'swiglu_mlp': swiglu_mlp,
          'qkv_fused': qkv_fused,
          'scale_mlp_norm': scale_mlp_norm,
          'scale_attn_inner_norm': scale_attn_inner_norm,
      }
      # Whatever is still left in kwargs is passed through and overrides the mapped values.
      naflex_kwargs.update(kwargs)
      return _create_naflexvit(variant, pretrained, **naflex_kwargs)
  @register_model
  def naflexvit_base_patch16_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
      """ViT-Base (patch 16) NaFlex model with global average pooling, 4 register tokens and fc_norm.
      """
      arch = dict(
          patch_size=16,
          embed_dim=768,
          depth=12,
          num_heads=12,
          init_values=1e-5,
          global_pool='avg',
          reg_tokens=4,
          fc_norm=True,
      )
      return _create_naflexvit(
          'naflexvit_base_patch16_gap',
          pretrained=pretrained,
          cfg=NaFlexVitCfg(**arch),
          **kwargs,
      )
  @register_model
  def naflexvit_base_patch16_par_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
      """ViT-Base (patch 16) NaFlex model with aspect preserving pos embed and global average pooling.
      """
      arch = dict(
          patch_size=16,
          embed_dim=768,
          depth=12,
          num_heads=12,
          init_values=1e-5,
          pos_embed_ar_preserving=True,
          global_pool='avg',
          reg_tokens=4,
          fc_norm=True,
      )
      return _create_naflexvit(
          'naflexvit_base_patch16_par_gap',
          pretrained=pretrained,
          cfg=NaFlexVitCfg(**arch),
          **kwargs,
      )
  @register_model
  def naflexvit_base_patch16_parfac_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
      """ViT-Base (patch 16) NaFlex model with aspect preserving, factorized pos embed and
      global average pooling.
      """
      arch = dict(
          patch_size=16,
          embed_dim=768,
          depth=12,
          num_heads=12,
          init_values=1e-5,
          pos_embed_ar_preserving=True,
          pos_embed='factorized',
          global_pool='avg',
          reg_tokens=4,
          fc_norm=True,
      )
      return _create_naflexvit(
          'naflexvit_base_patch16_parfac_gap',
          pretrained=pretrained,
          cfg=NaFlexVitCfg(**arch),
          **kwargs,
      )
  @register_model
  def naflexvit_base_patch16_map(pretrained: bool = False, **kwargs) -> NaFlexVit:
      """ViT-Base (patch 16) NaFlex model with MAP attention pooling and 1 register token.
      """
      arch = dict(
          patch_size=16,
          embed_dim=768,
          depth=12,
          num_heads=12,
          init_values=1e-5,
          global_pool='map',
          reg_tokens=1,
      )
      return _create_naflexvit(
          'naflexvit_base_patch16_map',
          pretrained=pretrained,
          cfg=NaFlexVitCfg(**arch),
          **kwargs,
      )
  @register_model
  def naflexvit_so150m2_patch16_reg1_gap(pretrained: bool = False, **kwargs) -> NaFlexVit:
      """ViT-SO150M2 with NaFlex functionality, 1 register token and global average pooling.

      This model supports:
      1. Variable aspect ratios and resolutions via patch coordinates
      2. Position embedding interpolation for arbitrary grid sizes
      3. Explicit patch coordinates and valid token masking
      """
      arch = dict(
          patch_size=16,
          embed_dim=832,
          depth=21,
          num_heads=13,
          mlp_ratio=34 / 13,
          init_values=1e-5,
          qkv_bias=False,
          reg_tokens=1,
          global_pool='avg',
          fc_norm=True,
      )
      return _create_naflexvit(
          'naflexvit_so150m2_patch16_reg1_gap',
          pretrained=pretrained,
          cfg=NaFlexVitCfg(**arch),
          **kwargs,
      )
  @register_model
  def naflexvit_so150m2_patch16_reg1_map(pretrained: bool = False, **kwargs) -> NaFlexVit:
      """ViT-SO150M2 with NaFlex functionality, 1 register token and MAP attention pooling.

      This model supports:
      1. Variable aspect ratios and resolutions via patch coordinates
      2. Position embedding interpolation for arbitrary grid sizes
      3. Explicit patch coordinates and valid token masking
      """
      arch = dict(
          patch_size=16,
          embed_dim=832,
          depth=21,
          num_heads=13,
          mlp_ratio=34 / 13,
          init_values=1e-5,
          qkv_bias=False,
          reg_tokens=1,
          global_pool='map',
      )
      return _create_naflexvit(
          'naflexvit_so150m2_patch16_reg1_map',
          pretrained=pretrained,
          cfg=NaFlexVitCfg(**arch),
          **kwargs,
      )
  @register_model
  def naflexvit_base_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit:
      """ViT-Base (patch 16) NaFlex model in SigLIP-style configuration: gelu_tanh activation, MAP pooling.
      """
      arch = dict(
          patch_size=16,
          embed_dim=768,
          depth=12,
          num_heads=12,
          act_layer='gelu_tanh',
          global_pool='map',
      )
      return _create_naflexvit(
          'naflexvit_base_patch16_siglip',
          pretrained=pretrained,
          cfg=NaFlexVitCfg(**arch),
          **kwargs,
      )
  @register_model
  def naflexvit_so400m_patch16_siglip(pretrained: bool = False, **kwargs) -> NaFlexVit:
      """ViT-SO400M (patch 16) NaFlex model in SigLIP-style configuration: gelu_tanh activation, MAP pooling.
      """
      arch = dict(
          patch_size=16,
          embed_dim=1152,
          depth=27,
          num_heads=16,
          mlp_ratio=3.7362,
          act_layer='gelu_tanh',
          global_pool='map',
      )
      return _create_naflexvit(
          'naflexvit_so400m_patch16_siglip',
          pretrained=pretrained,
          cfg=NaFlexVitCfg(**arch),
          **kwargs,
      )