hieradet_sam2.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699
  1. import math
  2. from copy import deepcopy
  3. from functools import partial
  4. from typing import Dict, List, Optional, Tuple, Type, Union
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.functional as F
  8. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  9. from timm.layers import (
  10. PatchEmbed,
  11. Mlp,
  12. DropPath,
  13. calculate_drop_path_rates,
  14. ClNormMlpClassifierHead,
  15. LayerScale,
  16. get_norm_layer,
  17. get_act_layer,
  18. init_weight_jax,
  19. init_weight_vit,
  20. to_2tuple,
  21. use_fused_attn,
  22. )
  23. from ._builder import build_model_with_cfg
  24. from ._features import feature_take_indices
  25. from ._manipulate import named_apply, checkpoint
  26. from ._registry import generate_default_cfgs, register_model
  27. def window_partition(x, window_size: Tuple[int, int]):
  28. """
  29. Partition into non-overlapping windows with padding if needed.
  30. Args:
  31. x (tensor): input tokens with [B, H, W, C].
  32. window_size (int): window size.
  33. Returns:
  34. windows: windows after partition with [B * num_windows, window_size, window_size, C].
  35. (Hp, Wp): padded height and width before partition
  36. """
  37. B, H, W, C = x.shape
  38. x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
  39. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
  40. return windows
  41. def window_unpartition(windows: torch.Tensor, window_size: Tuple[int, int], hw: Tuple[int, int]):
  42. """
  43. Window unpartition into original sequences and removing padding.
  44. Args:
  45. x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
  46. window_size (int): window size.
  47. hw (Tuple): original height and width (H, W) before padding.
  48. Returns:
  49. x: unpartitioned sequences with [B, H, W, C].
  50. """
  51. H, W = hw
  52. B = windows.shape[0] // (H * W // window_size[0] // window_size[1])
  53. x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
  54. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
  55. return x
  56. def _calc_pad(H: int, W: int, window_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
  57. pad_h = (window_size[0] - H % window_size[0]) % window_size[0]
  58. pad_w = (window_size[1] - W % window_size[1]) % window_size[1]
  59. Hp, Wp = H + pad_h, W + pad_w
  60. return Hp, Wp, pad_h, pad_w
  61. class MultiScaleAttention(nn.Module):
  62. fused_attn: torch.jit.Final[bool]
  63. def __init__(
  64. self,
  65. dim: int,
  66. dim_out: int,
  67. num_heads: int,
  68. q_pool: nn.Module = None,
  69. device=None,
  70. dtype=None,
  71. ):
  72. dd = {'device': device, 'dtype': dtype}
  73. super().__init__()
  74. self.dim = dim
  75. self.dim_out = dim_out
  76. self.num_heads = num_heads
  77. head_dim = dim_out // num_heads
  78. self.scale = head_dim ** -0.5
  79. self.fused_attn = use_fused_attn()
  80. self.q_pool = q_pool
  81. self.qkv = nn.Linear(dim, dim_out * 3, **dd)
  82. self.proj = nn.Linear(dim_out, dim_out, **dd)
  83. def forward(self, x: torch.Tensor) -> torch.Tensor:
  84. B, H, W, _ = x.shape
  85. # qkv with shape (B, H * W, 3, nHead, C)
  86. qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
  87. # q, k, v with shape (B, H * W, nheads, C)
  88. q, k, v = torch.unbind(qkv, 2)
  89. # Q pooling (for downsample at stage changes)
  90. if self.q_pool is not None:
  91. q = q.reshape(B, H, W, -1).permute(0, 3, 1, 2) # to BCHW for pool
  92. q = self.q_pool(q).permute(0, 2, 3, 1)
  93. H, W = q.shape[1:3] # downsampled shape
  94. q = q.reshape(B, H * W, self.num_heads, -1)
  95. # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
  96. q = q.transpose(1, 2)
  97. k = k.transpose(1, 2)
  98. v = v.transpose(1, 2)
  99. if self.fused_attn:
  100. x = F.scaled_dot_product_attention(q, k, v)
  101. else:
  102. q = q * self.scale
  103. attn = q @ k.transpose(-1, -2)
  104. attn = attn.softmax(dim=-1)
  105. x = attn @ v
  106. # Transpose back
  107. x = x.transpose(1, 2).reshape(B, H, W, -1)
  108. x = self.proj(x)
  109. return x
  110. class MultiScaleBlock(nn.Module):
  111. def __init__(
  112. self,
  113. dim: int,
  114. dim_out: int,
  115. num_heads: int,
  116. mlp_ratio: float = 4.0,
  117. q_stride: Optional[Tuple[int, int]] = None,
  118. norm_layer: Union[Type[nn.Module], str] = "LayerNorm",
  119. act_layer: Union[Type[nn.Module], str] = "GELU",
  120. window_size: int = 0,
  121. init_values: Optional[float] = None,
  122. drop_path: float = 0.0,
  123. device=None,
  124. dtype=None,
  125. ):
  126. dd = {'device': device, 'dtype': dtype}
  127. super().__init__()
  128. norm_layer = get_norm_layer(norm_layer)
  129. act_layer = get_act_layer(act_layer)
  130. self.window_size = to_2tuple(window_size)
  131. self.is_windowed = any(self.window_size)
  132. self.dim = dim
  133. self.dim_out = dim_out
  134. self.q_stride = q_stride
  135. if dim != dim_out:
  136. self.proj = nn.Linear(dim, dim_out, **dd)
  137. else:
  138. self.proj = nn.Identity()
  139. self.pool = None
  140. if self.q_stride:
  141. # note make a different instance for this Module so that it's not shared with attn module
  142. self.pool = nn.MaxPool2d(
  143. kernel_size=q_stride,
  144. stride=q_stride,
  145. ceil_mode=False,
  146. )
  147. self.norm1 = norm_layer(dim, **dd)
  148. self.attn = MultiScaleAttention(
  149. dim,
  150. dim_out,
  151. num_heads=num_heads,
  152. q_pool=deepcopy(self.pool),
  153. **dd,
  154. )
  155. self.ls1 = LayerScale(dim_out, init_values, **dd) if init_values is not None else nn.Identity()
  156. self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  157. self.norm2 = norm_layer(dim_out, **dd)
  158. self.mlp = Mlp(
  159. dim_out,
  160. int(dim_out * mlp_ratio),
  161. act_layer=act_layer,
  162. **dd,
  163. )
  164. self.ls2 = LayerScale(dim_out, init_values, **dd) if init_values is not None else nn.Identity()
  165. self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  166. def forward(self, x: torch.Tensor) -> torch.Tensor:
  167. shortcut = x # B, H, W, C
  168. x = self.norm1(x)
  169. # Skip connection
  170. if self.dim != self.dim_out:
  171. shortcut = self.proj(x)
  172. if self.pool is not None:
  173. shortcut = shortcut.permute(0, 3, 1, 2)
  174. shortcut = self.pool(shortcut).permute(0, 2, 3, 1)
  175. # Window partition
  176. window_size = self.window_size
  177. H, W = x.shape[1:3]
  178. Hp, Wp = H, W # keep torchscript happy
  179. if self.is_windowed:
  180. Hp, Wp, pad_h, pad_w = _calc_pad(H, W, window_size)
  181. x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
  182. x = window_partition(x, window_size)
  183. # Window Attention + Q Pooling (if stage change)
  184. x = self.attn(x)
  185. if self.q_stride is not None:
  186. # Shapes have changed due to Q pooling
  187. window_size = (self.window_size[0] // self.q_stride[0], self.window_size[1] // self.q_stride[1])
  188. H, W = shortcut.shape[1:3]
  189. Hp, Wp, pad_h, pad_w = _calc_pad(H, W, window_size)
  190. # Reverse window partition
  191. if self.is_windowed:
  192. x = window_unpartition(x, window_size, (Hp, Wp))
  193. x = x[:, :H, :W, :].contiguous() # unpad
  194. x = shortcut + self.drop_path1(self.ls1(x))
  195. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  196. return x
  197. class HieraPatchEmbed(nn.Module):
  198. """
  199. Image to Patch Embedding.
  200. """
  201. def __init__(
  202. self,
  203. kernel_size: Union[int, Tuple[int, int]] = (7, 7),
  204. stride: Union[int, Tuple[int, int]] = (4, 4),
  205. padding: Union[str, int, Tuple[int, int]] = (3, 3),
  206. in_chans: int = 3,
  207. embed_dim: int = 768,
  208. device=None,
  209. dtype=None,
  210. ):
  211. """
  212. Args:
  213. kernel_size: kernel size of the projection layer.
  214. stride: stride of the projection layer.
  215. padding: padding size of the projection layer.
  216. in_chans: Number of input image channels.
  217. embed_dim: Patch embedding dimension.
  218. """
  219. super().__init__()
  220. dd = {'device': device, 'dtype': dtype}
  221. self.proj = nn.Conv2d(
  222. in_chans,
  223. embed_dim,
  224. kernel_size=kernel_size,
  225. stride=stride,
  226. padding=padding,
  227. **dd,
  228. )
  229. def forward(self, x: torch.Tensor) -> torch.Tensor:
  230. x = self.proj(x)
  231. # B C H W -> B H W C
  232. x = x.permute(0, 2, 3, 1)
  233. return x
  234. class HieraDet(nn.Module):
  235. """
  236. Reference: https://arxiv.org/abs/2306.00989
  237. """
  238. def __init__(
  239. self,
  240. in_chans: int = 3,
  241. num_classes: int = 1000,
  242. global_pool: str = 'avg',
  243. embed_dim: int = 96, # initial embed dim
  244. num_heads: int = 1, # initial number of heads
  245. patch_kernel: Tuple[int, int] = (7, 7),
  246. patch_stride: Tuple[int, int] = (4, 4),
  247. patch_padding: Tuple[int, int] = (3, 3),
  248. patch_size: Optional[Tuple[int, int]] = None,
  249. q_pool: int = 3, # number of q_pool stages
  250. q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
  251. stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
  252. dim_mul: float = 2.0, # dim_mul factor at stage shift
  253. head_mul: float = 2.0, # head_mul factor at stage shift
  254. global_pos_size: Tuple[int, int] = (7, 7),
  255. # window size per stage, when not using global att.
  256. window_spec: Tuple[int, ...] = (
  257. 8,
  258. 4,
  259. 14,
  260. 7,
  261. ),
  262. # global attn in these blocks
  263. global_att_blocks: Tuple[int, ...] = (
  264. 12,
  265. 16,
  266. 20,
  267. ),
  268. init_values: Optional[float] = None,
  269. weight_init: str = '',
  270. fix_init: bool = True,
  271. head_init_scale: float = 0.001,
  272. drop_rate: float = 0.0,
  273. drop_path_rate: float = 0.0, # stochastic depth
  274. norm_layer: Union[Type[nn.Module], str] = "LayerNorm",
  275. act_layer: Union[Type[nn.Module], str] = "GELU",
  276. device=None,
  277. dtype=None,
  278. ):
  279. super().__init__()
  280. dd = {'device': device, 'dtype': dtype}
  281. norm_layer = get_norm_layer(norm_layer)
  282. act_layer = get_act_layer(act_layer)
  283. assert len(stages) == len(window_spec)
  284. self.grad_checkpointing = False
  285. self.num_classes = num_classes
  286. self.window_spec = window_spec
  287. self.output_fmt = 'NHWC'
  288. depth = sum(stages)
  289. self.q_stride = q_stride
  290. self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
  291. assert 0 <= q_pool <= len(self.stage_ends[:-1])
  292. self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
  293. if patch_size is not None:
  294. # use a non-overlapping vit style patch embed
  295. self.patch_embed = PatchEmbed(
  296. img_size=None,
  297. patch_size=patch_size,
  298. in_chans=in_chans,
  299. embed_dim=embed_dim,
  300. output_fmt='NHWC',
  301. dynamic_img_pad=True,
  302. **dd,
  303. )
  304. else:
  305. self.patch_embed = HieraPatchEmbed(
  306. kernel_size=patch_kernel,
  307. stride=patch_stride,
  308. padding=patch_padding,
  309. in_chans=in_chans,
  310. embed_dim=embed_dim,
  311. **dd,
  312. )
  313. # Which blocks have global att?
  314. self.global_att_blocks = global_att_blocks
  315. # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
  316. self.global_pos_size = global_pos_size
  317. self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.global_pos_size, **dd))
  318. self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0], **dd))
  319. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  320. cur_stage = 0
  321. self.blocks = nn.Sequential()
  322. self.feature_info = []
  323. for i in range(depth):
  324. dim_out = embed_dim
  325. # lags by a block, so first block of
  326. # next stage uses an initial window size
  327. # of previous stage and final window size of current stage
  328. window_size = self.window_spec[cur_stage]
  329. if self.global_att_blocks is not None:
  330. window_size = 0 if i in self.global_att_blocks else window_size
  331. if i - 1 in self.stage_ends:
  332. dim_out = int(embed_dim * dim_mul)
  333. num_heads = int(num_heads * head_mul)
  334. cur_stage += 1
  335. block = MultiScaleBlock(
  336. dim=embed_dim,
  337. dim_out=dim_out,
  338. num_heads=num_heads,
  339. drop_path=dpr[i],
  340. q_stride=self.q_stride if i in self.q_pool_blocks else None,
  341. window_size=window_size,
  342. norm_layer=norm_layer,
  343. act_layer=act_layer,
  344. init_values=init_values,
  345. **dd,
  346. )
  347. embed_dim = dim_out
  348. self.blocks.append(block)
  349. if i in self.stage_ends:
  350. self.feature_info += [
  351. dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
  352. self.num_features = self.head_hidden_size = embed_dim
  353. self.head = ClNormMlpClassifierHead(
  354. embed_dim,
  355. num_classes,
  356. pool_type=global_pool,
  357. drop_rate=drop_rate,
  358. norm_layer=norm_layer,
  359. **dd,
  360. )
  361. # Initialize everything
  362. if self.pos_embed is not None:
  363. nn.init.trunc_normal_(self.pos_embed, std=0.02)
  364. if self.pos_embed_window is not None:
  365. nn.init.trunc_normal_(self.pos_embed_window, std=0.02)
  366. if weight_init != 'skip':
  367. init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit
  368. init_fn = partial(init_fn, classifier_name='head.fc')
  369. named_apply(init_fn, self)
  370. if fix_init:
  371. self.fix_init_weight()
  372. if isinstance(self.head, ClNormMlpClassifierHead) and isinstance(self.head.fc, nn.Linear):
  373. self.head.fc.weight.data.mul_(head_init_scale)
  374. self.head.fc.bias.data.mul_(head_init_scale)
  375. def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
  376. h, w = x.shape[1:3]
  377. window_embed = self.pos_embed_window
  378. pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
  379. tile_h = pos_embed.shape[-2] // window_embed.shape[-2]
  380. tile_w = pos_embed.shape[-1] // window_embed.shape[-1]
  381. pos_embed = pos_embed + window_embed.tile((tile_h, tile_w))
  382. pos_embed = pos_embed.permute(0, 2, 3, 1)
  383. return x + pos_embed
  384. def fix_init_weight(self):
  385. def rescale(param, _layer_id):
  386. param.div_(math.sqrt(2.0 * _layer_id))
  387. for layer_id, layer in enumerate(self.blocks):
  388. rescale(layer.attn.proj.weight.data, layer_id + 1)
  389. rescale(layer.mlp.fc2.weight.data, layer_id + 1)
  390. @torch.jit.ignore
  391. def no_weight_decay(self):
  392. return ['pos_embed', 'pos_embed_window']
  393. @torch.jit.ignore
  394. def group_matcher(self, coarse: bool = False) -> Dict:
  395. return dict(
  396. stem=r'^pos_embed|pos_embed_window|patch_embed',
  397. blocks=[(r'^blocks\.(\d+)', None)]
  398. )
  399. @torch.jit.ignore
  400. def set_grad_checkpointing(self, enable: bool = True) -> None:
  401. self.grad_checkpointing = enable
  402. @torch.jit.ignore
  403. def get_classifier(self):
  404. return self.head.fc
  405. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
  406. self.num_classes = num_classes
  407. self.head.reset(num_classes, pool_type=global_pool, reset_other=reset_other)
  408. def forward_intermediates(
  409. self,
  410. x: torch.Tensor,
  411. indices: Optional[Union[int, List[int]]] = None,
  412. norm: bool = False,
  413. stop_early: bool = True,
  414. output_fmt: str = 'NCHW',
  415. intermediates_only: bool = False,
  416. coarse: bool = True,
  417. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  418. """ Forward features that returns intermediates.
  419. Args:
  420. x: Input image tensor
  421. indices: Take last n blocks if int, all if None, select matching indices if sequence
  422. norm: Apply norm layer to all intermediates
  423. stop_early: Stop iterating over blocks when last desired intermediate hit
  424. output_fmt: Shape of intermediate feature outputs
  425. intermediates_only: Only return intermediate features
  426. coarse: Take coarse features (stage ends) if true, otherwise all block featrures
  427. Returns:
  428. """
  429. assert not norm, 'normalization of features not supported'
  430. assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
  431. if coarse:
  432. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  433. take_indices = [self.stage_ends[i] for i in take_indices]
  434. max_index = self.stage_ends[max_index]
  435. else:
  436. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  437. x = self.patch_embed(x)
  438. x = self._pos_embed(x)
  439. intermediates = []
  440. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  441. blocks = self.blocks
  442. else:
  443. blocks = self.blocks[:max_index + 1]
  444. for i, blk in enumerate(blocks):
  445. if self.grad_checkpointing and not torch.jit.is_scripting():
  446. x = checkpoint(blk, x)
  447. else:
  448. x = blk(x)
  449. if i in take_indices:
  450. x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
  451. intermediates.append(x_out)
  452. if intermediates_only:
  453. return intermediates
  454. return x, intermediates
  455. def prune_intermediate_layers(
  456. self,
  457. indices: Union[int, List[int]] = 1,
  458. prune_norm: bool = False,
  459. prune_head: bool = True,
  460. coarse: bool = True,
  461. ):
  462. """ Prune layers not required for specified intermediates.
  463. """
  464. if coarse:
  465. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  466. max_index = self.stage_ends[max_index]
  467. else:
  468. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  469. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  470. if prune_head:
  471. self.head.reset(0, reset_other=prune_norm)
  472. return take_indices
  473. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  474. x = self.patch_embed(x) # BHWC
  475. x = self._pos_embed(x)
  476. for blk in self.blocks:
  477. if self.grad_checkpointing and not torch.jit.is_scripting():
  478. x = checkpoint(blk, x)
  479. else:
  480. x = blk(x)
  481. return x
  482. def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
  483. x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  484. return x
  485. def forward(self, x: torch.Tensor) -> torch.Tensor:
  486. x = self.forward_features(x)
  487. x = self.forward_head(x)
  488. return x
  489. # NOTE sam2 appears to use 1024x1024 for all models, but T, S, & B+ have windows that fit multiples of 224.
  490. def _cfg(url='', **kwargs):
  491. return {
  492. 'url': url,
  493. 'num_classes': 0, 'input_size': (3, 896, 896), 'pool_size': (28, 28),
  494. 'crop_pct': 1.0, 'interpolation': 'bicubic', 'min_input_size': (3, 224, 224),
  495. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  496. 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
  497. 'license': 'apache-2.0',
  498. **kwargs
  499. }
# Pretrained configs keyed by 'model_name.tag'. Every entry starts from the `_cfg`
# defaults and overrides only what differs.
# The commented-out hf_hub_id / hf_hub_filename pairs presumably point at the original
# upstream checkpoints; the active entries use the 'timm/' hub namespace instead.
default_cfgs = generate_default_cfgs({
    "sam2_hiera_tiny.fb_r896": _cfg(
        # hf_hub_id='facebook/sam2-hiera-tiny',
        # hf_hub_filename='sam2_hiera_tiny.pt',
        hf_hub_id='timm/',
    ),
    "sam2_hiera_tiny.fb_r896_2pt1": _cfg(
        # hf_hub_id='facebook/sam2.1-hiera-tiny',
        # hf_hub_filename='sam2.1_hiera_tiny.pt',
        hf_hub_id='timm/',
    ),
    "sam2_hiera_small.fb_r896": _cfg(
        # hf_hub_id='facebook/sam2-hiera-small',
        # hf_hub_filename='sam2_hiera_small.pt',
        hf_hub_id='timm/',
    ),
    "sam2_hiera_small.fb_r896_2pt1": _cfg(
        # hf_hub_id='facebook/sam2.1-hiera-small',
        # hf_hub_filename='sam2.1_hiera_small.pt',
        hf_hub_id='timm/',
    ),
    "sam2_hiera_base_plus.fb_r896": _cfg(
        # hf_hub_id='facebook/sam2-hiera-base-plus',
        # hf_hub_filename='sam2_hiera_base_plus.pt',
        hf_hub_id='timm/',
    ),
    "sam2_hiera_base_plus.fb_r896_2pt1": _cfg(
        # hf_hub_id='facebook/sam2.1-hiera-base-plus',
        # hf_hub_filename='sam2.1_hiera_base_plus.pt',
        hf_hub_id='timm/',
    ),
    # The large variants override the input / pool sizes (1024 input, 32x32 pool).
    "sam2_hiera_large.fb_r1024": _cfg(
        # hf_hub_id='facebook/sam2-hiera-large',
        # hf_hub_filename='sam2_hiera_large.pt',
        hf_hub_id='timm/',
        min_input_size=(3, 256, 256),
        input_size=(3, 1024, 1024), pool_size=(32, 32),
    ),
    "sam2_hiera_large.fb_r1024_2pt1": _cfg(
        # hf_hub_id='facebook/sam2.1-hiera-large',
        # hf_hub_filename='sam2.1_hiera_large.pt',
        hf_hub_id='timm/',
        min_input_size=(3, 256, 256),
        input_size=(3, 1024, 1024), pool_size=(32, 32),
    ),
    # No hub id or url is given for this entry; it keeps a 1000-class head.
    "hieradet_small.untrained": _cfg(
        num_classes=1000,
        input_size=(3, 256, 256), pool_size=(8, 8),
    ),
})
  550. def checkpoint_filter_fn(state_dict, model=None, prefix=''):
  551. state_dict = state_dict.get('model', state_dict)
  552. output = {}
  553. for k, v in state_dict.items():
  554. if k.startswith(prefix):
  555. k = k.replace(prefix, '')
  556. else:
  557. continue
  558. k = k.replace('mlp.layers.0', 'mlp.fc1')
  559. k = k.replace('mlp.layers.1', 'mlp.fc2')
  560. output[k] = v
  561. return output
  562. def _create_hiera_det(variant: str, pretrained: bool = False, **kwargs) -> HieraDet:
  563. out_indices = kwargs.pop('out_indices', 4)
  564. checkpoint_prefix = ''
  565. # if 'sam2' in variant:
  566. # # SAM2 pretrained weights have no classifier or final norm-layer (`head.norm`)
  567. # # This is workaround loading with num_classes=0 w/o removing norm-layer.
  568. # kwargs.setdefault('pretrained_strict', False)
  569. # checkpoint_prefix = 'image_encoder.trunk.'
  570. return build_model_with_cfg(
  571. HieraDet,
  572. variant,
  573. pretrained,
  574. pretrained_filter_fn=partial(checkpoint_filter_fn, prefix=checkpoint_prefix),
  575. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  576. **kwargs,
  577. )
  578. @register_model
  579. def sam2_hiera_tiny(pretrained=False, **kwargs):
  580. model_args = dict(stages=(1, 2, 7, 2), global_att_blocks=(5, 7, 9))
  581. return _create_hiera_det('sam2_hiera_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
  582. @register_model
  583. def sam2_hiera_small(pretrained=False, **kwargs):
  584. model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13))
  585. return _create_hiera_det('sam2_hiera_small', pretrained=pretrained, **dict(model_args, **kwargs))
  586. @register_model
  587. def sam2_hiera_base_plus(pretrained=False, **kwargs):
  588. model_args = dict(embed_dim=112, num_heads=2, global_pos_size=(14, 14))
  589. return _create_hiera_det('sam2_hiera_base_plus', pretrained=pretrained, **dict(model_args, **kwargs))
  590. @register_model
  591. def sam2_hiera_large(pretrained=False, **kwargs):
  592. model_args = dict(
  593. embed_dim=144,
  594. num_heads=2,
  595. stages=(2, 6, 36, 4),
  596. global_att_blocks=(23, 33, 43),
  597. window_spec=(8, 4, 16, 8),
  598. )
  599. return _create_hiera_det('sam2_hiera_large', pretrained=pretrained, **dict(model_args, **kwargs))
  600. @register_model
  601. def hieradet_small(pretrained=False, **kwargs):
  602. model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8), init_values=1e-5)
  603. return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))
  604. # @register_model
  605. # def hieradet_base(pretrained=False, **kwargs):
  606. # model_args = dict(window_spec=(8, 4, 16, 8))
  607. # return _create_hiera_det('hieradet_base', pretrained=pretrained, **dict(model_args, **kwargs))