volo.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403
  1. """ Vision OutLOoker (VOLO) implementation
  2. Paper: `VOLO: Vision Outlooker for Visual Recognition` - https://arxiv.org/abs/2106.13112
  3. Code adapted from official impl at https://github.com/sail-sg/volo, original copyright in comment below
  4. Modifications and additions for timm by / Copyright 2022, Ross Wightman
  5. """
  6. # Copyright 2021 Sea Limited.
  7. #
  8. # Licensed under the Apache License, Version 2.0 (the "License");
  9. # you may not use this file except in compliance with the License.
  10. # You may obtain a copy of the License at
  11. #
  12. # http://www.apache.org/licenses/LICENSE-2.0
  13. #
  14. # Unless required by applicable law or agreed to in writing, software
  15. # distributed under the License is distributed on an "AS IS" BASIS,
  16. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  17. # See the License for the specific language governing permissions and
  18. # limitations under the License.
  19. import math
  20. from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
  21. import torch
  22. import torch.nn as nn
  23. import torch.nn.functional as F
  24. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  25. from timm.layers import DropPath, Mlp, to_2tuple, to_ntuple, trunc_normal_, use_fused_attn
  26. from ._builder import build_model_with_cfg
  27. from ._features import feature_take_indices
  28. from ._manipulate import checkpoint
  29. from ._registry import register_model, generate_default_cfgs
  30. __all__ = ['VOLO'] # model_registry will add each entrypoint fn to this
  31. class OutlookAttention(nn.Module):
  32. """Outlook attention mechanism for VOLO models."""
  33. def __init__(
  34. self,
  35. dim: int,
  36. num_heads: int,
  37. kernel_size: int = 3,
  38. padding: int = 1,
  39. stride: int = 1,
  40. qkv_bias: bool = False,
  41. attn_drop: float = 0.,
  42. proj_drop: float = 0.,
  43. device=None,
  44. dtype=None,
  45. ):
  46. """Initialize OutlookAttention.
  47. Args:
  48. dim: Input feature dimension.
  49. num_heads: Number of attention heads.
  50. kernel_size: Kernel size for attention computation.
  51. padding: Padding for attention computation.
  52. stride: Stride for attention computation.
  53. qkv_bias: Whether to use bias in linear layers.
  54. attn_drop: Attention dropout rate.
  55. proj_drop: Projection dropout rate.
  56. """
  57. dd = {'device': device, 'dtype': dtype}
  58. super().__init__()
  59. head_dim = dim // num_heads
  60. self.num_heads = num_heads
  61. self.kernel_size = kernel_size
  62. self.padding = padding
  63. self.stride = stride
  64. self.scale = head_dim ** -0.5
  65. self.v = nn.Linear(dim, dim, bias=qkv_bias, **dd)
  66. self.attn = nn.Linear(dim, kernel_size ** 4 * num_heads, **dd)
  67. self.attn_drop = nn.Dropout(attn_drop)
  68. self.proj = nn.Linear(dim, dim, **dd)
  69. self.proj_drop = nn.Dropout(proj_drop)
  70. self.unfold = nn.Unfold(kernel_size=kernel_size, padding=padding, stride=stride)
  71. self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True)
  72. def forward(self, x: torch.Tensor) -> torch.Tensor:
  73. """Forward pass.
  74. Args:
  75. x: Input tensor of shape (B, H, W, C).
  76. Returns:
  77. Output tensor of shape (B, H, W, C).
  78. """
  79. B, H, W, C = x.shape
  80. v = self.v(x).permute(0, 3, 1, 2) # B, C, H, W
  81. h, w = math.ceil(H / self.stride), math.ceil(W / self.stride)
  82. v = self.unfold(v).reshape(
  83. B, self.num_heads, C // self.num_heads,
  84. self.kernel_size * self.kernel_size, h * w).permute(0, 1, 4, 3, 2) # B,H,N,kxk,C/H
  85. attn = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
  86. attn = self.attn(attn).reshape(
  87. B, h * w, self.num_heads, self.kernel_size * self.kernel_size,
  88. self.kernel_size * self.kernel_size).permute(0, 2, 1, 3, 4) # B,H,N,kxk,kxk
  89. attn = attn * self.scale
  90. attn = attn.softmax(dim=-1)
  91. attn = self.attn_drop(attn)
  92. x = (attn @ v).permute(0, 1, 4, 3, 2).reshape(B, C * self.kernel_size * self.kernel_size, h * w)
  93. x = F.fold(x, output_size=(H, W), kernel_size=self.kernel_size, padding=self.padding, stride=self.stride)
  94. x = self.proj(x.permute(0, 2, 3, 1))
  95. x = self.proj_drop(x)
  96. return x
  97. class Outlooker(nn.Module):
  98. """Outlooker block that combines outlook attention with MLP."""
  99. def __init__(
  100. self,
  101. dim: int,
  102. kernel_size: int,
  103. padding: int,
  104. stride: int = 1,
  105. num_heads: int = 1,
  106. mlp_ratio: float = 3.,
  107. attn_drop: float = 0.,
  108. drop_path: float = 0.,
  109. act_layer: Type[nn.Module] = nn.GELU,
  110. norm_layer: Type[nn.Module] = nn.LayerNorm,
  111. qkv_bias: bool = False,
  112. device=None,
  113. dtype=None,
  114. ):
  115. """Initialize Outlooker block.
  116. Args:
  117. dim: Input feature dimension.
  118. kernel_size: Kernel size for outlook attention.
  119. padding: Padding for outlook attention.
  120. stride: Stride for outlook attention.
  121. num_heads: Number of attention heads.
  122. mlp_ratio: Ratio for MLP hidden dimension.
  123. attn_drop: Attention dropout rate.
  124. drop_path: Stochastic depth drop rate.
  125. act_layer: Activation layer type.
  126. norm_layer: Normalization layer type.
  127. qkv_bias: Whether to use bias in linear layers.
  128. """
  129. dd = {'device': device, 'dtype': dtype}
  130. super().__init__()
  131. self.norm1 = norm_layer(dim, **dd)
  132. self.attn = OutlookAttention(
  133. dim,
  134. num_heads,
  135. kernel_size=kernel_size,
  136. padding=padding,
  137. stride=stride,
  138. qkv_bias=qkv_bias,
  139. attn_drop=attn_drop,
  140. **dd,
  141. )
  142. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  143. self.norm2 = norm_layer(dim, **dd)
  144. self.mlp = Mlp(
  145. in_features=dim,
  146. hidden_features=int(dim * mlp_ratio),
  147. act_layer=act_layer,
  148. **dd,
  149. )
  150. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  151. def forward(self, x: torch.Tensor) -> torch.Tensor:
  152. """Forward pass.
  153. Args:
  154. x: Input tensor.
  155. Returns:
  156. Output tensor.
  157. """
  158. x = x + self.drop_path1(self.attn(self.norm1(x)))
  159. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  160. return x
  161. class Attention(nn.Module):
  162. """Multi-head self-attention module."""
  163. fused_attn: torch.jit.Final[bool]
  164. def __init__(
  165. self,
  166. dim: int,
  167. num_heads: int = 8,
  168. qkv_bias: bool = False,
  169. attn_drop: float = 0.,
  170. proj_drop: float = 0.,
  171. device=None,
  172. dtype=None,
  173. ):
  174. """Initialize Attention module.
  175. Args:
  176. dim: Input feature dimension.
  177. num_heads: Number of attention heads.
  178. qkv_bias: Whether to use bias in QKV projection.
  179. attn_drop: Attention dropout rate.
  180. proj_drop: Projection dropout rate.
  181. """
  182. dd = {'device': device, 'dtype': dtype}
  183. super().__init__()
  184. self.num_heads = num_heads
  185. head_dim = dim // num_heads
  186. self.scale = head_dim ** -0.5
  187. self.fused_attn = use_fused_attn()
  188. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **dd)
  189. self.attn_drop = nn.Dropout(attn_drop)
  190. self.proj = nn.Linear(dim, dim, **dd)
  191. self.proj_drop = nn.Dropout(proj_drop)
  192. def forward(self, x: torch.Tensor) -> torch.Tensor:
  193. """Forward pass.
  194. Args:
  195. x: Input tensor of shape (B, H, W, C).
  196. Returns:
  197. Output tensor of shape (B, H, W, C).
  198. """
  199. B, H, W, C = x.shape
  200. qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  201. q, k, v = qkv.unbind(0)
  202. if self.fused_attn:
  203. x = F.scaled_dot_product_attention(
  204. q, k, v,
  205. dropout_p=self.attn_drop.p if self.training else 0.,
  206. )
  207. else:
  208. q = q * self.scale
  209. attn = q @ k.transpose(-2, -1)
  210. attn = attn.softmax(dim=-1)
  211. attn = self.attn_drop(attn)
  212. x = attn @ v
  213. x = x.transpose(1, 2).reshape(B, H, W, C)
  214. x = self.proj(x)
  215. x = self.proj_drop(x)
  216. return x
  217. class Transformer(nn.Module):
  218. """Transformer block with multi-head self-attention and MLP."""
  219. def __init__(
  220. self,
  221. dim: int,
  222. num_heads: int,
  223. mlp_ratio: float = 4.,
  224. qkv_bias: bool = False,
  225. attn_drop: float = 0.,
  226. drop_path: float = 0.,
  227. act_layer: Type[nn.Module] = nn.GELU,
  228. norm_layer: Type[nn.Module] = nn.LayerNorm,
  229. device=None,
  230. dtype=None,
  231. ):
  232. """Initialize Transformer block.
  233. Args:
  234. dim: Input feature dimension.
  235. num_heads: Number of attention heads.
  236. mlp_ratio: Ratio for MLP hidden dimension.
  237. qkv_bias: Whether to use bias in QKV projection.
  238. attn_drop: Attention dropout rate.
  239. drop_path: Stochastic depth drop rate.
  240. act_layer: Activation layer type.
  241. norm_layer: Normalization layer type.
  242. """
  243. dd = {'device': device, 'dtype': dtype}
  244. super().__init__()
  245. self.norm1 = norm_layer(dim, **dd)
  246. self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, **dd)
  247. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  248. self.norm2 = norm_layer(dim, **dd)
  249. self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, **dd)
  250. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  251. def forward(self, x: torch.Tensor) -> torch.Tensor:
  252. """Forward pass.
  253. Args:
  254. x: Input tensor.
  255. Returns:
  256. Output tensor.
  257. """
  258. x = x + self.drop_path1(self.attn(self.norm1(x)))
  259. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  260. return x
  261. class ClassAttention(nn.Module):
  262. """Class attention mechanism for class token interaction."""
  263. def __init__(
  264. self,
  265. dim: int,
  266. num_heads: int = 8,
  267. head_dim: Optional[int] = None,
  268. qkv_bias: bool = False,
  269. attn_drop: float = 0.,
  270. proj_drop: float = 0.,
  271. device=None,
  272. dtype=None,
  273. ):
  274. """Initialize ClassAttention.
  275. Args:
  276. dim: Input feature dimension.
  277. num_heads: Number of attention heads.
  278. head_dim: Dimension per head. If None, computed as dim // num_heads.
  279. qkv_bias: Whether to use bias in QKV projection.
  280. attn_drop: Attention dropout rate.
  281. proj_drop: Projection dropout rate.
  282. """
  283. dd = {'device': device, 'dtype': dtype}
  284. super().__init__()
  285. self.num_heads = num_heads
  286. if head_dim is not None:
  287. self.head_dim = head_dim
  288. else:
  289. head_dim = dim // num_heads
  290. self.head_dim = head_dim
  291. self.scale = head_dim ** -0.5
  292. self.kv = nn.Linear(dim, self.head_dim * self.num_heads * 2, bias=qkv_bias, **dd)
  293. self.q = nn.Linear(dim, self.head_dim * self.num_heads, bias=qkv_bias, **dd)
  294. self.attn_drop = nn.Dropout(attn_drop)
  295. self.proj = nn.Linear(self.head_dim * self.num_heads, dim, **dd)
  296. self.proj_drop = nn.Dropout(proj_drop)
  297. def forward(self, x: torch.Tensor) -> torch.Tensor:
  298. """Forward pass.
  299. Args:
  300. x: Input tensor of shape (B, N, C) where first token is class token.
  301. Returns:
  302. Class token output of shape (B, 1, C).
  303. """
  304. B, N, C = x.shape
  305. kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
  306. k, v = kv.unbind(0)
  307. q = self.q(x[:, :1, :]).reshape(B, self.num_heads, 1, self.head_dim) * self.scale
  308. attn = q @ k.transpose(-2, -1)
  309. attn = attn.softmax(dim=-1)
  310. attn = self.attn_drop(attn)
  311. cls_embed = (attn @ v).transpose(1, 2).reshape(B, 1, self.head_dim * self.num_heads)
  312. cls_embed = self.proj(cls_embed)
  313. cls_embed = self.proj_drop(cls_embed)
  314. return cls_embed
  315. class ClassBlock(nn.Module):
  316. """Class block that combines class attention with MLP."""
  317. def __init__(
  318. self,
  319. dim: int,
  320. num_heads: int,
  321. head_dim: Optional[int] = None,
  322. mlp_ratio: float = 4.,
  323. qkv_bias: bool = False,
  324. drop: float = 0.,
  325. attn_drop: float = 0.,
  326. drop_path: float = 0.,
  327. act_layer: Type[nn.Module] = nn.GELU,
  328. norm_layer: Type[nn.Module] = nn.LayerNorm,
  329. device=None,
  330. dtype=None,
  331. ):
  332. """Initialize ClassBlock.
  333. Args:
  334. dim: Input feature dimension.
  335. num_heads: Number of attention heads.
  336. head_dim: Dimension per head. If None, computed as dim // num_heads.
  337. mlp_ratio: Ratio for MLP hidden dimension.
  338. qkv_bias: Whether to use bias in QKV projection.
  339. drop: Dropout rate.
  340. attn_drop: Attention dropout rate.
  341. drop_path: Stochastic depth drop rate.
  342. act_layer: Activation layer type.
  343. norm_layer: Normalization layer type.
  344. """
  345. dd = {'device': device, 'dtype': dtype}
  346. super().__init__()
  347. self.norm1 = norm_layer(dim, **dd)
  348. self.attn = ClassAttention(
  349. dim,
  350. num_heads=num_heads,
  351. head_dim=head_dim,
  352. qkv_bias=qkv_bias,
  353. attn_drop=attn_drop,
  354. proj_drop=drop,
  355. **dd,
  356. )
  357. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  358. self.norm2 = norm_layer(dim, **dd)
  359. self.mlp = Mlp(
  360. in_features=dim,
  361. hidden_features=int(dim * mlp_ratio),
  362. act_layer=act_layer,
  363. drop=drop,
  364. **dd,
  365. )
  366. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  367. def forward(self, x: torch.Tensor) -> torch.Tensor:
  368. """Forward pass.
  369. Args:
  370. x: Input tensor of shape (B, N, C) where first token is class token.
  371. Returns:
  372. Output tensor with updated class token.
  373. """
  374. cls_embed = x[:, :1]
  375. cls_embed = cls_embed + self.drop_path1(self.attn(self.norm1(x)))
  376. cls_embed = cls_embed + self.drop_path2(self.mlp(self.norm2(cls_embed)))
  377. return torch.cat([cls_embed, x[:, 1:]], dim=1)
  378. def get_block(block_type: str, **kwargs: Any) -> nn.Module:
  379. """Get block based on type.
  380. Args:
  381. block_type: Type of block ('ca' for ClassBlock).
  382. **kwargs: Additional keyword arguments for block.
  383. Returns:
  384. The requested block module.
  385. """
  386. if block_type == 'ca':
  387. return ClassBlock(**kwargs)
  388. else:
  389. assert False, f'Invalid block type: {block_type}'
  390. def rand_bbox(size: Tuple[int, ...], lam: float, scale: int = 1) -> Tuple[int, int, int, int]:
  391. """Get random bounding box for token labeling.
  392. Reference: https://github.com/zihangJiang/TokenLabeling
  393. Args:
  394. size: Input tensor size tuple.
  395. lam: Lambda parameter for cutmix.
  396. scale: Scaling factor.
  397. Returns:
  398. Bounding box coordinates (bbx1, bby1, bbx2, bby2).
  399. """
  400. W = size[1] // scale
  401. H = size[2] // scale
  402. W_t = torch.tensor(W, dtype=torch.float32)
  403. H_t = torch.tensor(H, dtype=torch.float32)
  404. cut_rat = torch.sqrt(1. - lam)
  405. cut_w = (W_t * cut_rat).int()
  406. cut_h = (H_t * cut_rat).int()
  407. # uniform
  408. cx = torch.randint(0, W, (1,))
  409. cy = torch.randint(0, H, (1,))
  410. bbx1 = torch.clamp(cx - cut_w // 2, 0, W)
  411. bby1 = torch.clamp(cy - cut_h // 2, 0, H)
  412. bbx2 = torch.clamp(cx + cut_w // 2, 0, W)
  413. bby2 = torch.clamp(cy + cut_h // 2, 0, H)
  414. return bbx1.item(), bby1.item(), bbx2.item(), bby2.item()
  415. class PatchEmbed(nn.Module):
  416. """Image to patch embedding with multi-layer convolution."""
  417. def __init__(
  418. self,
  419. img_size: int = 224,
  420. stem_conv: bool = False,
  421. stem_stride: int = 1,
  422. patch_size: int = 8,
  423. in_chans: int = 3,
  424. hidden_dim: int = 64,
  425. embed_dim: int = 384,
  426. device=None,
  427. dtype=None,
  428. ):
  429. """Initialize PatchEmbed.
  430. Different from ViT which uses 1 conv layer, VOLO uses multiple conv layers for patch embedding.
  431. Args:
  432. img_size: Input image size.
  433. stem_conv: Whether to use stem convolution layers.
  434. stem_stride: Stride for stem convolution.
  435. patch_size: Patch size (must be 4, 8, or 16).
  436. in_chans: Number of input channels.
  437. hidden_dim: Hidden dimension for stem convolution.
  438. embed_dim: Output embedding dimension.
  439. """
  440. dd = {'device': device, 'dtype': dtype}
  441. super().__init__()
  442. assert patch_size in [4, 8, 16]
  443. if stem_conv:
  444. self.conv = nn.Sequential(
  445. nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False, **dd),
  446. nn.BatchNorm2d(hidden_dim, **dd),
  447. nn.ReLU(inplace=True),
  448. nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False, **dd),
  449. nn.BatchNorm2d(hidden_dim, **dd),
  450. nn.ReLU(inplace=True),
  451. nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False, **dd),
  452. nn.BatchNorm2d(hidden_dim, **dd),
  453. nn.ReLU(inplace=True),
  454. )
  455. else:
  456. self.conv = None
  457. self.proj = nn.Conv2d(
  458. hidden_dim,
  459. embed_dim,
  460. kernel_size=patch_size // stem_stride,
  461. stride=patch_size // stem_stride,
  462. **dd,
  463. )
  464. self.num_patches = (img_size // patch_size) * (img_size // patch_size)
  465. def forward(self, x: torch.Tensor) -> torch.Tensor:
  466. """Forward pass.
  467. Args:
  468. x: Input tensor of shape (B, C, H, W).
  469. Returns:
  470. Output tensor of shape (B, embed_dim, H', W').
  471. """
  472. if self.conv is not None:
  473. x = self.conv(x)
  474. x = self.proj(x) # B, C, H, W
  475. return x
  476. class Downsample(nn.Module):
  477. """Downsampling module between stages."""
  478. def __init__(
  479. self,
  480. in_embed_dim: int,
  481. out_embed_dim: int,
  482. patch_size: int = 2,
  483. device=None,
  484. dtype=None,
  485. ):
  486. """Initialize Downsample.
  487. Args:
  488. in_embed_dim: Input embedding dimension.
  489. out_embed_dim: Output embedding dimension.
  490. patch_size: Patch size for downsampling.
  491. """
  492. super().__init__()
  493. dd = {'device': device, 'dtype': dtype}
  494. self.proj = nn.Conv2d(in_embed_dim, out_embed_dim, kernel_size=patch_size, stride=patch_size, **dd)
  495. def forward(self, x: torch.Tensor) -> torch.Tensor:
  496. """Forward pass.
  497. Args:
  498. x: Input tensor of shape (B, H, W, C).
  499. Returns:
  500. Output tensor of shape (B, H', W', C').
  501. """
  502. x = x.permute(0, 3, 1, 2)
  503. x = self.proj(x) # B, C, H, W
  504. x = x.permute(0, 2, 3, 1)
  505. return x
  506. def outlooker_blocks(
  507. block_fn: Callable,
  508. index: int,
  509. dim: int,
  510. layers: List[int],
  511. num_heads: int = 1,
  512. kernel_size: int = 3,
  513. padding: int = 1,
  514. stride: int = 2,
  515. mlp_ratio: float = 3.,
  516. qkv_bias: bool = False,
  517. attn_drop: float = 0,
  518. drop_path_rate: float = 0.,
  519. device=None,
  520. dtype=None,
  521. **kwargs: Any,
  522. ) -> nn.Sequential:
  523. """Generate outlooker layers for stage 1.
  524. Args:
  525. block_fn: Block function to use (typically Outlooker).
  526. index: Index of current stage.
  527. dim: Feature dimension.
  528. layers: List of layer counts for each stage.
  529. num_heads: Number of attention heads.
  530. kernel_size: Kernel size for outlook attention.
  531. padding: Padding for outlook attention.
  532. stride: Stride for outlook attention.
  533. mlp_ratio: Ratio for MLP hidden dimension.
  534. qkv_bias: Whether to use bias in QKV projection.
  535. attn_drop: Attention dropout rate.
  536. drop_path_rate: Stochastic depth drop rate.
  537. **kwargs: Additional keyword arguments.
  538. Returns:
  539. Sequential module containing outlooker blocks.
  540. """
  541. blocks = []
  542. for block_idx in range(layers[index]):
  543. block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
  544. blocks.append(block_fn(
  545. dim,
  546. kernel_size=kernel_size,
  547. padding=padding,
  548. stride=stride,
  549. num_heads=num_heads,
  550. mlp_ratio=mlp_ratio,
  551. qkv_bias=qkv_bias,
  552. attn_drop=attn_drop,
  553. drop_path=block_dpr,
  554. device=device,
  555. dtype=dtype,
  556. **kwargs,
  557. ))
  558. blocks = nn.Sequential(*blocks)
  559. return blocks
  560. def transformer_blocks(
  561. block_fn: Callable,
  562. index: int,
  563. dim: int,
  564. layers: List[int],
  565. num_heads: int,
  566. mlp_ratio: float = 3.,
  567. qkv_bias: bool = False,
  568. attn_drop: float = 0,
  569. drop_path_rate: float = 0.,
  570. **kwargs: Any,
  571. ) -> nn.Sequential:
  572. """Generate transformer layers for stage 2.
  573. Args:
  574. block_fn: Block function to use (typically Transformer).
  575. index: Index of current stage.
  576. dim: Feature dimension.
  577. layers: List of layer counts for each stage.
  578. num_heads: Number of attention heads.
  579. mlp_ratio: Ratio for MLP hidden dimension.
  580. qkv_bias: Whether to use bias in QKV projection.
  581. attn_drop: Attention dropout rate.
  582. drop_path_rate: Stochastic depth drop rate.
  583. **kwargs: Additional keyword arguments.
  584. Returns:
  585. Sequential module containing transformer blocks.
  586. """
  587. blocks = []
  588. for block_idx in range(layers[index]):
  589. block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
  590. blocks.append(block_fn(
  591. dim,
  592. num_heads,
  593. mlp_ratio=mlp_ratio,
  594. qkv_bias=qkv_bias,
  595. attn_drop=attn_drop,
  596. drop_path=block_dpr,
  597. **kwargs,
  598. ))
  599. blocks = nn.Sequential(*blocks)
  600. return blocks
  601. class VOLO(nn.Module):
  602. """Vision Outlooker (VOLO) model."""
  603. def __init__(
  604. self,
  605. layers: List[int],
  606. img_size: int = 224,
  607. in_chans: int = 3,
  608. num_classes: int = 1000,
  609. global_pool: str = 'token',
  610. patch_size: int = 8,
  611. stem_hidden_dim: int = 64,
  612. embed_dims: Optional[List[int]] = None,
  613. num_heads: Optional[List[int]] = None,
  614. downsamples: Tuple[bool, ...] = (True, False, False, False),
  615. outlook_attention: Tuple[bool, ...] = (True, False, False, False),
  616. mlp_ratio: float = 3.0,
  617. qkv_bias: bool = False,
  618. drop_rate: float = 0.,
  619. pos_drop_rate: float = 0.,
  620. attn_drop_rate: float = 0.,
  621. drop_path_rate: float = 0.,
  622. norm_layer: Type[nn.Module] = nn.LayerNorm,
  623. post_layers: Optional[Tuple[str, ...]] = ('ca', 'ca'),
  624. use_aux_head: bool = True,
  625. use_mix_token: bool = False,
  626. pooling_scale: int = 2,
  627. device=None,
  628. dtype=None,
  629. ):
  630. """Initialize VOLO model.
  631. Args:
  632. layers: Number of blocks in each stage.
  633. img_size: Input image size.
  634. in_chans: Number of input channels.
  635. num_classes: Number of classes for classification.
  636. global_pool: Global pooling type ('token', 'avg', or '').
  637. patch_size: Patch size for patch embedding.
  638. stem_hidden_dim: Hidden dimension for stem convolution.
  639. embed_dims: List of embedding dimensions for each stage.
  640. num_heads: List of number of attention heads for each stage.
  641. downsamples: Whether to downsample between stages.
  642. outlook_attention: Whether to use outlook attention in each stage.
  643. mlp_ratio: Ratio for MLP hidden dimension.
  644. qkv_bias: Whether to use bias in QKV projection.
  645. drop_rate: Dropout rate.
  646. pos_drop_rate: Position embedding dropout rate.
  647. attn_drop_rate: Attention dropout rate.
  648. drop_path_rate: Stochastic depth drop rate.
  649. norm_layer: Normalization layer type.
  650. post_layers: Post-processing layer types.
  651. use_aux_head: Whether to use auxiliary head.
  652. use_mix_token: Whether to use token mixing for training.
  653. pooling_scale: Pooling scale factor.
  654. """
  655. super().__init__()
  656. dd = {'device': device, 'dtype': dtype}
  657. num_layers = len(layers)
  658. mlp_ratio = to_ntuple(num_layers)(mlp_ratio)
  659. img_size = to_2tuple(img_size)
  660. self.num_classes = num_classes
  661. self.global_pool = global_pool
  662. self.mix_token = use_mix_token
  663. self.pooling_scale = pooling_scale
  664. self.num_features = self.head_hidden_size = embed_dims[-1]
  665. if use_mix_token: # enable token mixing, see token labeling for details.
  666. self.beta = 1.0
  667. assert global_pool == 'token', "return all tokens if mix_token is enabled"
  668. self.grad_checkpointing = False
  669. self.patch_embed = PatchEmbed(
  670. stem_conv=True,
  671. stem_stride=2,
  672. patch_size=patch_size,
  673. in_chans=in_chans,
  674. hidden_dim=stem_hidden_dim,
  675. embed_dim=embed_dims[0],
  676. **dd,
  677. )
  678. r = patch_size
  679. # initial positional encoding, we add positional encoding after outlooker blocks
  680. patch_grid = (img_size[0] // patch_size // pooling_scale, img_size[1] // patch_size // pooling_scale)
  681. self.pos_embed = nn.Parameter(torch.zeros(1, patch_grid[0], patch_grid[1], embed_dims[-1], **dd))
  682. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  683. # set the main block in network
  684. self.stage_ends = []
  685. self.feature_info = []
  686. network = []
  687. block_idx = 0
  688. for i in range(len(layers)):
  689. if outlook_attention[i]:
  690. # stage 1
  691. stage = outlooker_blocks(
  692. Outlooker,
  693. i,
  694. embed_dims[i],
  695. layers,
  696. num_heads[i],
  697. mlp_ratio=mlp_ratio[i],
  698. qkv_bias=qkv_bias,
  699. attn_drop=attn_drop_rate,
  700. norm_layer=norm_layer,
  701. **dd,
  702. )
  703. else:
  704. # stage 2
  705. stage = transformer_blocks(
  706. Transformer,
  707. i,
  708. embed_dims[i],
  709. layers,
  710. num_heads[i],
  711. mlp_ratio=mlp_ratio[i],
  712. qkv_bias=qkv_bias,
  713. drop_path_rate=drop_path_rate,
  714. attn_drop=attn_drop_rate,
  715. norm_layer=norm_layer,
  716. **dd,
  717. )
  718. network.append(stage)
  719. self.stage_ends.append(block_idx)
  720. self.feature_info.append(dict(num_chs=embed_dims[i], reduction=r, module=f'network.{block_idx}'))
  721. block_idx += 1
  722. if downsamples[i]:
  723. # downsampling between two stages
  724. network.append(Downsample(embed_dims[i], embed_dims[i + 1], 2, **dd))
  725. r *= 2
  726. block_idx += 1
  727. self.network = nn.ModuleList(network)
  728. # set post block, for example, class attention layers
  729. self.post_network = None
  730. if post_layers is not None:
  731. self.post_network = nn.ModuleList([
  732. get_block(
  733. post_layers[i],
  734. dim=embed_dims[-1],
  735. num_heads=num_heads[-1],
  736. mlp_ratio=mlp_ratio[-1],
  737. qkv_bias=qkv_bias,
  738. attn_drop=attn_drop_rate,
  739. drop_path=0.,
  740. norm_layer=norm_layer,
  741. **dd,
  742. )
  743. for i in range(len(post_layers))
  744. ])
  745. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[-1], **dd))
  746. trunc_normal_(self.cls_token, std=.02)
  747. # set output type
  748. if use_aux_head:
  749. self.aux_head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  750. else:
  751. self.aux_head = None
  752. self.norm = norm_layer(self.num_features, **dd)
  753. # Classifier head
  754. self.head_drop = nn.Dropout(drop_rate)
  755. self.head = nn.Linear(self.num_features, num_classes, **dd) if num_classes > 0 else nn.Identity()
  756. trunc_normal_(self.pos_embed, std=.02)
  757. self.apply(self._init_weights)
  758. def _init_weights(self, m: nn.Module) -> None:
  759. """Initialize weights for modules.
  760. Args:
  761. m: Module to initialize.
  762. """
  763. if isinstance(m, nn.Linear):
  764. trunc_normal_(m.weight, std=.02)
  765. if isinstance(m, nn.Linear) and m.bias is not None:
  766. nn.init.constant_(m.bias, 0)
  767. @torch.jit.ignore
  768. def no_weight_decay(self) -> set:
  769. """Get set of parameters that should not have weight decay.
  770. Returns:
  771. Set of parameter names.
  772. """
  773. return {'pos_embed', 'cls_token'}
  774. @torch.jit.ignore
  775. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  776. """Get parameter grouping for optimizer.
  777. Args:
  778. coarse: Whether to use coarse grouping.
  779. Returns:
  780. Parameter grouping dictionary.
  781. """
  782. return dict(
  783. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  784. blocks=[
  785. (r'^network\.(\d+)\.(\d+)', None),
  786. (r'^network\.(\d+)', (0,)),
  787. ],
  788. blocks2=[
  789. (r'^cls_token', (0,)),
  790. (r'^post_network\.(\d+)', None),
  791. (r'^norm', (99999,))
  792. ],
  793. )
  794. @torch.jit.ignore
  795. def set_grad_checkpointing(self, enable: bool = True) -> None:
  796. """Set gradient checkpointing.
  797. Args:
  798. enable: Whether to enable gradient checkpointing.
  799. """
  800. self.grad_checkpointing = enable
  801. @torch.jit.ignore
  802. def get_classifier(self) -> nn.Module:
  803. """Get classifier module.
  804. Returns:
  805. The classifier head module.
  806. """
  807. return self.head
  808. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  809. """Reset classifier head.
  810. Args:
  811. num_classes: Number of classes for new classifier.
  812. global_pool: Global pooling type.
  813. """
  814. self.num_classes = num_classes
  815. if global_pool is not None:
  816. self.global_pool = global_pool
  817. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  818. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  819. self.head = nn.Linear(
  820. self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  821. if self.aux_head is not None:
  822. self.aux_head = nn.Linear(
  823. self.num_features, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  824. def forward_tokens(self, x: torch.Tensor) -> torch.Tensor:
  825. """Forward pass through token processing stages.
  826. Args:
  827. x: Input tensor of shape (B, H, W, C).
  828. Returns:
  829. Token tensor of shape (B, N, C).
  830. """
  831. for idx, block in enumerate(self.network):
  832. if idx == 2:
  833. # add positional encoding after outlooker blocks
  834. x = x + self.pos_embed
  835. x = self.pos_drop(x)
  836. if self.grad_checkpointing and not torch.jit.is_scripting():
  837. x = checkpoint(block, x)
  838. else:
  839. x = block(x)
  840. B, H, W, C = x.shape
  841. x = x.reshape(B, -1, C)
  842. return x
  843. def forward_cls(self, x: torch.Tensor) -> torch.Tensor:
  844. """Forward pass through class attention blocks.
  845. Args:
  846. x: Input token tensor of shape (B, N, C).
  847. Returns:
  848. Output tensor with class token of shape (B, N+1, C).
  849. """
  850. B, N, C = x.shape
  851. cls_tokens = self.cls_token.expand(B, -1, -1)
  852. x = torch.cat([cls_tokens, x], dim=1)
  853. for block in self.post_network:
  854. if self.grad_checkpointing and not torch.jit.is_scripting():
  855. x = checkpoint(block, x)
  856. else:
  857. x = block(x)
  858. return x
  859. def forward_train(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, Tuple[int, int, int, int]]]:
  860. """Forward pass for training with mix token support.
  861. Args:
  862. x: Input tensor of shape (B, C, H, W).
  863. Returns:
  864. If training with mix_token: tuple of (class_token, aux_tokens, bbox).
  865. Otherwise: class_token tensor.
  866. """
  867. """ A separate forward fn for training with mix_token (if a train script supports).
  868. Combining multiple modes in as single forward with different return types is torchscript hell.
  869. """
  870. x = self.patch_embed(x)
  871. x = x.permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
  872. # mix token, see token labeling for details.
  873. if self.mix_token and self.training:
  874. lam = torch.distributions.Beta(self.beta, self.beta).sample()
  875. patch_h, patch_w = x.shape[1] // self.pooling_scale, x.shape[2] // self.pooling_scale
  876. bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam, scale=self.pooling_scale)
  877. temp_x = x.clone()
  878. sbbx1, sbby1 = self.pooling_scale * bbx1, self.pooling_scale * bby1
  879. sbbx2, sbby2 = self.pooling_scale * bbx2, self.pooling_scale * bby2
  880. temp_x[:, sbbx1:sbbx2, sbby1:sbby2, :] = x.flip(0)[:, sbbx1:sbbx2, sbby1:sbby2, :]
  881. x = temp_x
  882. else:
  883. bbx1, bby1, bbx2, bby2 = 0, 0, 0, 0
  884. # step2: tokens learning in the two stages
  885. x = self.forward_tokens(x)
  886. # step3: post network, apply class attention or not
  887. if self.post_network is not None:
  888. x = self.forward_cls(x)
  889. x = self.norm(x)
  890. if self.global_pool == 'avg':
  891. x_cls = x.mean(dim=1)
  892. elif self.global_pool == 'token':
  893. x_cls = x[:, 0]
  894. else:
  895. x_cls = x
  896. if self.aux_head is None:
  897. return x_cls
  898. x_aux = self.aux_head(x[:, 1:]) # generate classes in all feature tokens, see token labeling
  899. if not self.training:
  900. return x_cls + 0.5 * x_aux.max(1)[0]
  901. if self.mix_token and self.training: # reverse "mix token", see token labeling for details.
  902. x_aux = x_aux.reshape(x_aux.shape[0], patch_h, patch_w, x_aux.shape[-1])
  903. temp_x = x_aux.clone()
  904. temp_x[:, bbx1:bbx2, bby1:bby2, :] = x_aux.flip(0)[:, bbx1:bbx2, bby1:bby2, :]
  905. x_aux = temp_x
  906. x_aux = x_aux.reshape(x_aux.shape[0], patch_h * patch_w, x_aux.shape[-1])
  907. # return these: 1. class token, 2. classes from all feature tokens, 3. bounding box
  908. return x_cls, x_aux, (bbx1, bby1, bbx2, bby2)
  909. def forward_intermediates(
  910. self,
  911. x: torch.Tensor,
  912. indices: Optional[Union[int, List[int]]] = None,
  913. norm: bool = False,
  914. stop_early: bool = False,
  915. output_fmt: str = 'NCHW',
  916. intermediates_only: bool = False,
  917. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  918. """ Forward features that returns intermediates.
  919. Args:
  920. x: Input image tensor
  921. indices: Take last n blocks if int, all if None, select matching indices if sequence
  922. norm: Apply norm layer to all intermediates
  923. stop_early: Stop iterating over blocks when last desired intermediate hit
  924. output_fmt: Shape of intermediate feature outputs
  925. intermediates_only: Only return intermediate features
  926. Returns:
  927. """
  928. assert output_fmt in ('NCHW',), 'Output format must be NCHW.'
  929. intermediates = []
  930. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  931. take_indices = [self.stage_ends[i] for i in take_indices]
  932. max_index = self.stage_ends[max_index]
  933. # forward pass
  934. B, _, height, width = x.shape
  935. x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
  936. # step2: tokens learning in the two stages
  937. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  938. network = self.network
  939. else:
  940. network = self.network[:max_index + 1]
  941. for idx, block in enumerate(network):
  942. if idx == 2:
  943. # add positional encoding after outlooker blocks
  944. x = x + self.pos_embed
  945. x = self.pos_drop(x)
  946. if self.grad_checkpointing and not torch.jit.is_scripting():
  947. x = checkpoint(block, x)
  948. else:
  949. x = block(x)
  950. if idx in take_indices:
  951. if norm and idx >= 2:
  952. x_inter = self.norm(x)
  953. else:
  954. x_inter = x
  955. intermediates.append(x_inter.permute(0, 3, 1, 2))
  956. if intermediates_only:
  957. return intermediates
  958. # NOTE not supporting return of class tokens
  959. # step3: post network, apply class attention or not
  960. B, H, W, C = x.shape
  961. x = x.reshape(B, -1, C)
  962. if self.post_network is not None:
  963. x = self.forward_cls(x)
  964. x = self.norm(x)
  965. return x, intermediates
  966. def prune_intermediate_layers(
  967. self,
  968. indices: Union[int, List[int]] = 1,
  969. prune_norm: bool = False,
  970. prune_head: bool = True,
  971. ) -> List[int]:
  972. """Prune layers not required for specified intermediates.
  973. Args:
  974. indices: Indices of intermediate layers to keep.
  975. prune_norm: Whether to prune normalization layer.
  976. prune_head: Whether to prune classification head.
  977. Returns:
  978. List of kept intermediate indices.
  979. """
  980. """ Prune layers not required for specified intermediates.
  981. """
  982. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  983. max_index = self.stage_ends[max_index]
  984. self.network = self.network[:max_index + 1] # truncate blocks
  985. if prune_norm:
  986. self.norm = nn.Identity()
  987. if prune_head:
  988. self.post_network = nn.ModuleList() # prune token blocks with head
  989. self.reset_classifier(0, '')
  990. return take_indices
  991. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  992. """Forward pass through feature extraction.
  993. Args:
  994. x: Input tensor of shape (B, C, H, W).
  995. Returns:
  996. Feature tensor.
  997. """
  998. x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
  999. # step2: tokens learning in the two stages
  1000. x = self.forward_tokens(x)
  1001. # step3: post network, apply class attention or not
  1002. if self.post_network is not None:
  1003. x = self.forward_cls(x)
  1004. x = self.norm(x)
  1005. return x
  1006. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  1007. """Forward pass through classification head.
  1008. Args:
  1009. x: Input feature tensor.
  1010. pre_logits: Whether to return pre-logits features.
  1011. Returns:
  1012. Classification logits or pre-logits features.
  1013. """
  1014. if self.global_pool == 'avg':
  1015. out = x.mean(dim=1)
  1016. elif self.global_pool == 'token':
  1017. out = x[:, 0]
  1018. else:
  1019. out = x
  1020. x = self.head_drop(x)
  1021. if pre_logits:
  1022. return out
  1023. out = self.head(out)
  1024. if self.aux_head is not None:
  1025. # generate classes in all feature tokens, see token labeling
  1026. aux = self.aux_head(x[:, 1:])
  1027. out = out + 0.5 * aux.max(1)[0]
  1028. return out
  1029. def forward(self, x: torch.Tensor) -> torch.Tensor:
  1030. """Forward pass (simplified, without mix token training).
  1031. Args:
  1032. x: Input tensor of shape (B, C, H, W).
  1033. Returns:
  1034. Classification logits.
  1035. """
  1036. """ simplified forward (without mix token training) """
  1037. x = self.forward_features(x)
  1038. x = self.forward_head(x)
  1039. return x
  1040. def _create_volo(variant: str, pretrained: bool = False, **kwargs: Any) -> VOLO:
  1041. """Create VOLO model.
  1042. Args:
  1043. variant: Model variant name.
  1044. pretrained: Whether to load pretrained weights.
  1045. **kwargs: Additional model arguments.
  1046. Returns:
  1047. VOLO model instance.
  1048. """
  1049. out_indices = kwargs.pop('out_indices', 3)
  1050. return build_model_with_cfg(
  1051. VOLO,
  1052. variant,
  1053. pretrained,
  1054. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  1055. **kwargs,
  1056. )
  1057. def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
  1058. """Create model configuration.
  1059. Args:
  1060. url: URL for pretrained weights.
  1061. **kwargs: Additional configuration options.
  1062. Returns:
  1063. Model configuration dictionary.
  1064. """
  1065. return {
  1066. 'url': url,
  1067. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  1068. 'crop_pct': .96, 'interpolation': 'bicubic', 'fixed_input_size': True,
  1069. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  1070. 'first_conv': 'patch_embed.conv.0', 'classifier': ('head', 'aux_head'),
  1071. 'license': 'apache-2.0',
  1072. **kwargs
  1073. }
# Pretrained configurations, one entry per '<model name>.<tag>' key. Each entry
# starts from the `_cfg` defaults (224x224 input, crop_pct 0.96); the
# higher-resolution variants override `input_size` and `crop_pct`.
default_cfgs = generate_default_cfgs({
    # VOLO-D1
    'volo_d1_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar',
        crop_pct=0.96),
    'volo_d1_384.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar',
        crop_pct=1.0, input_size=(3, 384, 384)),
    # VOLO-D2
    'volo_d2_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar',
        crop_pct=0.96),
    'volo_d2_384.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar',
        crop_pct=1.0, input_size=(3, 384, 384)),
    # VOLO-D3
    'volo_d3_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar',
        crop_pct=0.96),
    'volo_d3_448.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar',
        crop_pct=1.0, input_size=(3, 448, 448)),
    # VOLO-D4
    'volo_d4_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar',
        crop_pct=0.96),
    'volo_d4_448.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar',
        crop_pct=1.15, input_size=(3, 448, 448)),
    # VOLO-D5
    'volo_d5_224.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar',
        crop_pct=0.96),
    'volo_d5_448.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar',
        crop_pct=1.15, input_size=(3, 448, 448)),
    'volo_d5_512.sail_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar',
        crop_pct=1.15, input_size=(3, 512, 512)),
})
  1120. @register_model
  1121. def volo_d1_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1122. """VOLO-D1 model, Params: 27M."""
  1123. model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
  1124. model = _create_volo('volo_d1_224', pretrained=pretrained, **model_args)
  1125. return model
  1126. @register_model
  1127. def volo_d1_384(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1128. """VOLO-D1 model, Params: 27M."""
  1129. model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
  1130. model = _create_volo('volo_d1_384', pretrained=pretrained, **model_args)
  1131. return model
  1132. @register_model
  1133. def volo_d2_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1134. """VOLO-D2 model, Params: 59M."""
  1135. model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
  1136. model = _create_volo('volo_d2_224', pretrained=pretrained, **model_args)
  1137. return model
  1138. @register_model
  1139. def volo_d2_384(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1140. """VOLO-D2 model, Params: 59M."""
  1141. model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
  1142. model = _create_volo('volo_d2_384', pretrained=pretrained, **model_args)
  1143. return model
  1144. @register_model
  1145. def volo_d3_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1146. """VOLO-D3 model, Params: 86M."""
  1147. model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
  1148. model = _create_volo('volo_d3_224', pretrained=pretrained, **model_args)
  1149. return model
  1150. @register_model
  1151. def volo_d3_448(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1152. """VOLO-D3 model, Params: 86M."""
  1153. model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
  1154. model = _create_volo('volo_d3_448', pretrained=pretrained, **model_args)
  1155. return model
  1156. @register_model
  1157. def volo_d4_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1158. """VOLO-D4 model, Params: 193M."""
  1159. model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
  1160. model = _create_volo('volo_d4_224', pretrained=pretrained, **model_args)
  1161. return model
  1162. @register_model
  1163. def volo_d4_448(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1164. """VOLO-D4 model, Params: 193M."""
  1165. model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
  1166. model = _create_volo('volo_d4_448', pretrained=pretrained, **model_args)
  1167. return model
  1168. @register_model
  1169. def volo_d5_224(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1170. """VOLO-D5 model, Params: 296M.
  1171. stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5.
  1172. """
  1173. model_args = dict(
  1174. layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
  1175. mlp_ratio=4, stem_hidden_dim=128, **kwargs)
  1176. model = _create_volo('volo_d5_224', pretrained=pretrained, **model_args)
  1177. return model
  1178. @register_model
  1179. def volo_d5_448(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1180. """VOLO-D5 model, Params: 296M.
  1181. stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5.
  1182. """
  1183. model_args = dict(
  1184. layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
  1185. mlp_ratio=4, stem_hidden_dim=128, **kwargs)
  1186. model = _create_volo('volo_d5_448', pretrained=pretrained, **model_args)
  1187. return model
  1188. @register_model
  1189. def volo_d5_512(pretrained: bool = False, **kwargs: Any) -> VOLO:
  1190. """VOLO-D5 model, Params: 296M.
  1191. stem_hidden_dim=128, the dim in patch embedding is 128 for VOLO-D5.
  1192. """
  1193. model_args = dict(
  1194. layers=(12, 12, 20, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16),
  1195. mlp_ratio=4, stem_hidden_dim=128, **kwargs)
  1196. model = _create_volo('volo_d5_512', pretrained=pretrained, **model_args)
  1197. return model