hiera.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025
""" A PyTorch implementation of Hiera
  2. Adapted for timm from originals at https://github.com/facebookresearch/hiera
  3. """
  4. # Copyright (c) Meta Platforms, Inc. and affiliates.
  5. # All rights reserved.
  6. # This source code is licensed under the license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. # --------------------------------------------------------
  9. #
  10. # Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles
  11. #
  12. # Chaitanya Ryali, Yuan-Ting Hu, Daniel Bolya, Chen Wei, Haoqi Fan,
  13. # Po-Yao Huang, Vaibhav Aggarwal, Arkabandhu Chowdhury, Omid Poursaeed,
  14. # Judy Hoffman, Jitendra Malik, Yanghao Li, Christoph Feichtenhofer.
  15. #
  16. # Paper: https://arxiv.org/abs/2306.00989/
  17. #
  18. # References:
  19. # slowfast: https://github.com/facebookresearch/SlowFast
  20. # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
  21. # --------------------------------------------------------
  22. import math
  23. from functools import partial
  24. from typing import Dict, List, Optional, Tuple, Type, Union
  25. import torch
  26. import torch.nn as nn
  27. import torch.nn.functional as F
  28. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  29. from timm.layers import (
  30. DropPath,
  31. calculate_drop_path_rates,
  32. Mlp,
  33. LayerScale,
  34. ClNormMlpClassifierHead,
  35. use_fused_attn,
  36. _assert,
  37. get_norm_layer,
  38. to_2tuple,
  39. init_weight_vit,
  40. init_weight_jax,
  41. )
  42. from ._registry import generate_default_cfgs, register_model
  43. from ._builder import build_model_with_cfg
  44. from ._features import feature_take_indices
  45. from ._features_fx import register_notrace_function
  46. from ._manipulate import named_apply, checkpoint
  47. __all__ = ['Hiera']
  48. def conv_nd(n: int) -> Type[nn.Module]:
  49. """
  50. Returns a conv with nd (e.g., Conv2d for n=2). Work up to n=3.
  51. If you wanted a 4d Hiera, you could probably just implement this for n=4. (no promises)
  52. """
  53. return [nn.Identity, nn.Conv1d, nn.Conv2d, nn.Conv3d][n]
  54. @register_notrace_function
  55. def get_resized_mask(target_size: List[int], mask: torch.Tensor) -> torch.Tensor:
  56. # target_size: [(T), (H), W]
  57. # (spatial) mask: [B, C, (t), (h), w]
  58. if mask is None:
  59. return mask
  60. _assert(len(mask.shape[2:]) == len(target_size), "mask spatial shape and target_size must match.")
  61. if mask.shape[2:] != target_size:
  62. return F.interpolate(mask.float(), size=target_size)
  63. return mask
  64. def undo_windowing(
  65. x: torch.Tensor,
  66. shape: List[int],
  67. mu_shape: List[int],
  68. ) -> torch.Tensor:
  69. """
  70. Restore spatial organization by undoing windowed organization of mask units.
  71. Args:
  72. x: organized by mask units windows, e.g. in 2d [B, #MUy*#MUx, MUy, MUx, C]
  73. shape: current spatial shape, if it were not organized into mask unit
  74. windows, e.g. in 2d [B, #MUy*MUy, #MUx*MUx, C].
  75. mu_shape: current mask unit shape, e.g. in 2d [MUy, MUx]
  76. Returns:
  77. x: e.g. in 2d, [B, #MUy*MUy, #MUx*MUx, C]
  78. """
  79. D = len(shape)
  80. B, C = x.shape[0], x.shape[-1]
  81. # [B, #MUy*#MUx, MUy, MUx, C] -> [B, #MUy, #MUx, MUy, MUx, C]
  82. num_MUs = [s // mu for s, mu in zip(shape, mu_shape)]
  83. x = x.view(B, *num_MUs, *mu_shape, C)
  84. # [B, #MUy, #MUx, MUy, MUx, C] -> [B, #MUy*MUy, #MUx*MUx, C]
  85. permute = (
  86. [0]
  87. + sum([list(p) for p in zip(range(1, 1 + D), range(1 + D, 1 + 2 * D))], [])
  88. + [len(x.shape) - 1]
  89. )
  90. x = x.permute(permute).reshape(B, *shape, C)
  91. return x
  92. class Unroll(nn.Module):
  93. """
  94. Reorders the tokens such that patches are contiguous in memory.
  95. E.g., given [B, (H, W), C] and stride of (Sy, Sx), this will re-order the tokens as
  96. [B, (Sy, Sx, H // Sy, W // Sx), C]
  97. This allows operations like Max2d to be computed as x.view(B, Sx*Sy, -1, C).max(dim=1).
  98. Not only is this faster, but it also makes it easy to support inputs of arbitrary
  99. dimensions in addition to patch-wise sparsity.
  100. Performing this operation multiple times in sequence puts entire windows as contiguous
  101. in memory. For instance, if you applied the stride (2, 2) 3 times, entire windows of
  102. size 8x8 would be contiguous in memory, allowing operations like mask unit attention
  103. computed easily and efficiently, while also allowing max to be applied sequentially.
  104. Note: This means that intermediate values of the model are not in HxW order, so they
  105. need to be re-rolled if you want to use the intermediate values as a HxW feature map.
  106. The last block of the network is fine though, since by then the strides are all consumed.
  107. """
  108. def __init__(
  109. self,
  110. input_size: Tuple[int, ...],
  111. patch_stride: Tuple[int, ...],
  112. unroll_schedule: List[Tuple[int, ...]],
  113. ):
  114. super().__init__()
  115. self.size = [i // s for i, s in zip(input_size, patch_stride)]
  116. self.schedule = unroll_schedule
  117. def forward(self, x: torch.Tensor) -> torch.Tensor:
  118. """
  119. Input: Flattened patch embeddings [B, N, C]
  120. Output: Patch embeddings [B, N, C] permuted such that [B, 4, N//4, C].max(1) etc. performs MaxPoolNd
  121. """
  122. B, _, C = x.shape
  123. cur_size = self.size
  124. x = x.view(*([B] + cur_size + [C]))
  125. for strides in self.schedule:
  126. # Move patches with the given strides to the batch dimension
  127. # Create a view of the tensor with the patch stride as separate dims
  128. # For example in 2d: [B, H // Sy, Sy, W // Sx, Sx, C]
  129. cur_size = [i // s for i, s in zip(cur_size, strides)]
  130. new_shape = [B] + sum([[i, s] for i, s in zip(cur_size, strides)], []) + [C]
  131. x = x.view(new_shape)
  132. # Move the patch stride into the batch dimension
  133. # For example in 2d: [B, Sy, Sx, H // Sy, W // Sx, C]
  134. L = len(new_shape)
  135. permute = [0] + list(range(2, L - 1, 2)) + list(range(1, L - 1, 2)) + [L - 1]
  136. x = x.permute(permute)
  137. # Now finally flatten the relevant dims into the batch dimension
  138. x = x.flatten(0, len(strides))
  139. B *= math.prod(strides)
  140. x = x.reshape(-1, math.prod(self.size), C)
  141. return x
  142. class Reroll(nn.Module):
  143. """
  144. Undos the "unroll" operation so that you can use intermediate features.
  145. """
  146. def __init__(
  147. self,
  148. input_size: Tuple[int, ...],
  149. patch_stride: Tuple[int, ...],
  150. unroll_schedule: List[Tuple[int, ...]],
  151. stage_ends: List[int],
  152. q_pool: int,
  153. ):
  154. super().__init__()
  155. self.size = [i // s for i, s in zip(input_size, patch_stride)]
  156. # The first stage has to reverse everything
  157. # The next stage has to reverse all but the first unroll, etc.
  158. self.schedule = {}
  159. size = self.size
  160. for i in range(stage_ends[-1] + 1):
  161. self.schedule[i] = unroll_schedule, size
  162. # schedule unchanged if no pooling at a stage end
  163. if i in stage_ends[:q_pool]:
  164. if len(unroll_schedule) > 0:
  165. size = [n // s for n, s in zip(size, unroll_schedule[0])]
  166. unroll_schedule = unroll_schedule[1:]
  167. def forward(
  168. self,
  169. x: torch.Tensor,
  170. block_idx: int,
  171. mask: torch.Tensor = None
  172. ) -> torch.Tensor:
  173. """
  174. Roll the given tensor back up to spatial order assuming it's from the given block.
  175. If no mask is provided:
  176. - Returns [B, H, W, C] for 2d, [B, T, H, W, C] for 3d, etc.
  177. If a mask is provided:
  178. - Returns [B, #MUs, MUy, MUx, C] for 2d, etc.
  179. """
  180. schedule, size = self.schedule[block_idx]
  181. B, N, C = x.shape
  182. D = len(size)
  183. cur_mu_shape = [1] * D
  184. for strides in schedule:
  185. # Extract the current patch from N
  186. x = x.view(B, *strides, N // math.prod(strides), *cur_mu_shape, C)
  187. # Move that patch into the current MU
  188. # Example in 2d: [B, Sy, Sx, N//(Sy*Sx), MUy, MUx, C] -> [B, N//(Sy*Sx), Sy, MUy, Sx, MUx, C]
  189. L = len(x.shape)
  190. permute = (
  191. [0, 1 + D]
  192. + sum([list(p) for p in zip(range(1, 1 + D), range(1 + D + 1, L - 1))], [])
  193. + [L - 1]
  194. )
  195. x = x.permute(permute)
  196. # Reshape to [B, N//(Sy*Sx), *MU, C]
  197. for i in range(D):
  198. cur_mu_shape[i] *= strides[i]
  199. x = x.reshape(B, -1, *cur_mu_shape, C)
  200. N = x.shape[1]
  201. # Current shape (e.g., 2d: [B, #MUy*#MUx, MUy, MUx, C])
  202. x = x.view(B, N, *cur_mu_shape, C)
  203. # If masked, return [B, #MUs, MUy, MUx, C]
  204. if mask is not None:
  205. return x
  206. # If not masked, we can return [B, H, W, C]
  207. x = undo_windowing(x, size, cur_mu_shape)
  208. return x
  209. class MaskUnitAttention(nn.Module):
  210. """
  211. Computes either Mask Unit or Global Attention. Also is able to perform q pooling.
  212. Note: this assumes the tokens have already been flattened and unrolled into mask units.
  213. See `Unroll` for more details.
  214. """
  215. fused_attn: torch.jit.Final[bool]
  216. def __init__(
  217. self,
  218. dim: int,
  219. dim_out: int,
  220. heads: int,
  221. q_stride: int = 1,
  222. window_size: int = 0,
  223. use_mask_unit_attn: bool = False,
  224. device=None,
  225. dtype=None,
  226. ):
  227. """
  228. Args:
  229. - dim, dim_out: The input and output feature dimensions.
  230. - heads: The number of attention heads.
  231. - q_stride: If greater than 1, pool q with this stride. The stride should be flattened (e.g., 2x2 = 4).
  232. - window_size: The current (flattened) size of a mask unit *after* pooling (if any).
  233. - use_mask_unit_attn: Use Mask Unit or Global Attention.
  234. """
  235. dd = {'device': device, 'dtype': dtype}
  236. super().__init__()
  237. self.dim = dim
  238. self.dim_out = dim_out
  239. self.heads = heads
  240. self.q_stride = q_stride
  241. self.head_dim = dim_out // heads
  242. self.scale = self.head_dim ** -0.5
  243. self.fused_attn = use_fused_attn()
  244. self.qkv = nn.Linear(dim, 3 * dim_out, **dd)
  245. self.proj = nn.Linear(dim_out, dim_out, **dd)
  246. self.window_size = window_size
  247. self.use_mask_unit_attn = use_mask_unit_attn
  248. def forward(self, x: torch.Tensor) -> torch.Tensor:
  249. """ Input should be of shape [batch, tokens, channels]. """
  250. B, N, _ = x.shape
  251. num_windows = (N // (self.q_stride * self.window_size)) if self.use_mask_unit_attn else 1
  252. qkv = self.qkv(x).reshape(B, -1, num_windows, 3, self.heads, self.head_dim).permute(3, 0, 4, 2, 1, 5)
  253. q, k, v = qkv.unbind(0)
  254. if self.q_stride > 1:
  255. # Refer to Unroll to see how this performs a maxpool-Nd
  256. q = q.view(B, self.heads, num_windows, self.q_stride, -1, self.head_dim).amax(dim=3)
  257. if self.fused_attn:
  258. # Note: the original paper did *not* use SDPA, it's a free boost!
  259. x = F.scaled_dot_product_attention(q, k, v)
  260. else:
  261. attn = (q * self.scale) @ k.transpose(-1, -2)
  262. attn = attn.softmax(dim=-1)
  263. x = attn @ v
  264. x = x.transpose(1, 3).reshape(B, -1, self.dim_out)
  265. x = self.proj(x)
  266. return x
  267. class HieraBlock(nn.Module):
  268. def __init__(
  269. self,
  270. dim: int,
  271. dim_out: int,
  272. heads: int,
  273. mlp_ratio: float = 4.0,
  274. drop_path: float = 0.0,
  275. init_values: Optional[float] = None,
  276. norm_layer: Type[nn.Module] = nn.LayerNorm,
  277. act_layer: Type[nn.Module] = nn.GELU,
  278. q_stride: int = 1,
  279. window_size: int = 0,
  280. use_expand_proj: bool = True,
  281. use_mask_unit_attn: bool = False,
  282. device=None,
  283. dtype=None,
  284. ):
  285. dd = {'device': device, 'dtype': dtype}
  286. super().__init__()
  287. self.dim = dim
  288. self.dim_out = dim_out
  289. self.norm1 = norm_layer(dim, **dd)
  290. if dim != dim_out:
  291. self.do_expand = True
  292. if use_expand_proj:
  293. self.proj = nn.Linear(dim, dim_out, **dd)
  294. else:
  295. assert dim_out == dim * 2
  296. self.proj = None
  297. else:
  298. self.do_expand = False
  299. self.proj = None
  300. self.attn = MaskUnitAttention(
  301. dim,
  302. dim_out,
  303. heads,
  304. q_stride,
  305. window_size,
  306. use_mask_unit_attn,
  307. **dd
  308. )
  309. self.ls1 = LayerScale(dim_out, init_values=init_values, **dd) if init_values is not None else nn.Identity()
  310. self.drop_path1 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
  311. self.norm2 = norm_layer(dim_out, **dd)
  312. self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer, **dd)
  313. self.ls2 = LayerScale(dim_out, init_values=init_values, **dd) if init_values is not None else nn.Identity()
  314. self.drop_path2 = DropPath(drop_path) if drop_path > 0 else nn.Identity()
  315. def forward(self, x: torch.Tensor) -> torch.Tensor:
  316. # Attention + Q Pooling
  317. x_norm = self.norm1(x)
  318. if self.do_expand:
  319. if self.proj is not None:
  320. x = self.proj(x_norm)
  321. x = x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1) # max-pool
  322. else:
  323. x = torch.cat([
  324. x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).amax(dim=1), # max-pool
  325. x.view(x.shape[0], self.attn.q_stride, -1, x.shape[-1]).mean(dim=1), # avg-pool
  326. ],
  327. dim=-1,
  328. )
  329. x = x + self.drop_path1(self.ls1(self.attn(x_norm)))
  330. # MLP
  331. x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
  332. return x
  333. class PatchEmbed(nn.Module):
  334. """Patch embed that supports any number of spatial dimensions (1d, 2d, 3d)."""
  335. def __init__(
  336. self,
  337. dim_in: int,
  338. dim_out: int,
  339. kernel: Tuple[int, ...],
  340. stride: Tuple[int, ...],
  341. padding: Tuple[int, ...],
  342. reshape: bool = True,
  343. device=None,
  344. dtype=None,
  345. ):
  346. dd = {'device': device, 'dtype': dtype}
  347. super().__init__()
  348. # Support any number of spatial dimensions
  349. self.spatial_dims = len(kernel)
  350. self.reshape = reshape
  351. self.proj = conv_nd(self.spatial_dims)(
  352. dim_in,
  353. dim_out,
  354. kernel_size=kernel,
  355. stride=stride,
  356. padding=padding,
  357. **dd,
  358. )
  359. def forward(
  360. self,
  361. x: torch.Tensor,
  362. mask: Optional[torch.Tensor] = None,
  363. ) -> torch.Tensor:
  364. if mask is not None:
  365. mask = get_resized_mask(target_size=x.shape[2:], mask=mask)
  366. x = self.proj(x * mask.to(torch.bool))
  367. else:
  368. x = self.proj(x)
  369. if self.reshape:
  370. x = x.reshape(x.shape[0], x.shape[1], -1).transpose(2, 1)
  371. return x
  372. class Hiera(nn.Module):
  373. def __init__(
  374. self,
  375. img_size: Tuple[int, ...] = (224, 224),
  376. in_chans: int = 3,
  377. embed_dim: int = 96, # initial embed dim
  378. num_heads: int = 1, # initial number of heads
  379. num_classes: int = 1000,
  380. global_pool: str = 'avg',
  381. stages: Tuple[int, ...] = (2, 3, 16, 3),
  382. q_pool: int = 3, # number of q_pool stages
  383. q_stride: Tuple[int, ...] = (2, 2),
  384. mask_unit_size: Tuple[int, ...] = (8, 8), # must divide q_stride ** (#stages-1)
  385. # mask_unit_attn: which stages use mask unit attention?
  386. mask_unit_attn: Tuple[bool, ...] = (True, True, False, False),
  387. use_expand_proj: bool = True,
  388. dim_mul: float = 2.0,
  389. head_mul: float = 2.0,
  390. patch_kernel: Tuple[int, ...] = (7, 7),
  391. patch_stride: Tuple[int, ...] = (4, 4),
  392. patch_padding: Tuple[int, ...] = (3, 3),
  393. mlp_ratio: float = 4.0,
  394. drop_path_rate: float = 0.0,
  395. init_values: Optional[float] = None,
  396. fix_init: bool = True,
  397. weight_init: str = '',
  398. norm_layer: Union[str, Type[nn.Module]] = "LayerNorm",
  399. drop_rate: float = 0.0,
  400. patch_drop_rate: float = 0.0,
  401. head_init_scale: float = 0.001,
  402. sep_pos_embed: bool = False,
  403. abs_win_pos_embed: bool = False,
  404. global_pos_size: Tuple[int, int] = (14, 14),
  405. device=None,
  406. dtype=None,
  407. ):
  408. super().__init__()
  409. dd = {'device': device, 'dtype': dtype}
  410. self.num_classes = num_classes
  411. self.grad_checkpointing = False
  412. norm_layer = get_norm_layer(norm_layer)
  413. if isinstance(img_size, int):
  414. img_size = to_2tuple(img_size)
  415. self.patch_stride = patch_stride
  416. self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
  417. num_tokens = math.prod(self.tokens_spatial_shape)
  418. flat_mu_size = math.prod(mask_unit_size)
  419. flat_q_stride = math.prod(q_stride)
  420. assert q_pool < len(stages)
  421. self.q_pool, self.q_stride = q_pool, q_stride
  422. self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
  423. self.mask_spatial_shape = [i // s for i, s in zip(self.tokens_spatial_shape, self.mask_unit_size)]
  424. self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
  425. self.patch_drop_rate = patch_drop_rate
  426. self.patch_embed = PatchEmbed(
  427. in_chans,
  428. embed_dim,
  429. patch_kernel,
  430. patch_stride,
  431. patch_padding,
  432. **dd,
  433. )
  434. self.pos_embed: Optional[nn.Parameter] = None
  435. self.pos_embed_win: Optional[nn.Parameter] = None
  436. self.pos_embed_spatial: Optional[nn.Parameter] = None
  437. self.pos_embed_temporal: Optional[nn.Parameter] = None
  438. if sep_pos_embed:
  439. self.pos_embed_spatial = nn.Parameter(
  440. torch.zeros(1, self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2], embed_dim, **dd)
  441. )
  442. self.pos_embed_temporal = nn.Parameter(
  443. torch.zeros(1, self.tokens_spatial_shape[0], embed_dim, **dd)
  444. )
  445. else:
  446. if abs_win_pos_embed:
  447. # absolute win, params NCHW to make tile & interpolate more natural before add & reshape
  448. self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *global_pos_size, **dd))
  449. self.pos_embed_win = nn.Parameter(torch.zeros(1, embed_dim, *mask_unit_size, **dd))
  450. else:
  451. self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens, embed_dim, **dd))
  452. # Setup roll and reroll modules
  453. self.unroll = Unroll(
  454. img_size,
  455. patch_stride,
  456. [q_stride] * len(self.stage_ends[:-1])
  457. )
  458. self.reroll = Reroll(
  459. img_size,
  460. patch_stride,
  461. [q_stride] * len(self.stage_ends[:-1]),
  462. self.stage_ends,
  463. q_pool,
  464. )
  465. # q_pool locations
  466. q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]
  467. # Transformer blocks
  468. cur_stage = 0
  469. depth = sum(stages)
  470. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  471. self.blocks = nn.ModuleList()
  472. self.feature_info = []
  473. for i in range(depth):
  474. dim_out = embed_dim
  475. # Mask unit or global attention.
  476. # Lag by 1 block, so that global attention,
  477. # applied post pooling on lower resolution
  478. use_mask_unit_attn = mask_unit_attn[cur_stage]
  479. if i - 1 in self.stage_ends:
  480. dim_out = int(embed_dim * dim_mul)
  481. num_heads = int(num_heads * head_mul)
  482. cur_stage += 1
  483. if i in q_pool_blocks:
  484. flat_mu_size //= flat_q_stride
  485. block = HieraBlock(
  486. dim=embed_dim,
  487. dim_out=dim_out,
  488. heads=num_heads,
  489. mlp_ratio=mlp_ratio,
  490. drop_path=dpr[i],
  491. init_values=init_values,
  492. norm_layer=norm_layer,
  493. q_stride=(flat_q_stride if i in q_pool_blocks else 1),
  494. window_size=flat_mu_size,
  495. use_expand_proj=use_expand_proj,
  496. use_mask_unit_attn=use_mask_unit_attn,
  497. **dd,
  498. )
  499. embed_dim = dim_out
  500. if i in self.stage_ends:
  501. self.feature_info += [
  502. dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
  503. self.blocks.append(block)
  504. self.num_features = self.head_hidden_size = embed_dim
  505. self.head = ClNormMlpClassifierHead(
  506. embed_dim,
  507. num_classes,
  508. pool_type=global_pool,
  509. drop_rate=drop_rate,
  510. norm_layer=norm_layer,
  511. input_fmt='NLC',
  512. **dd,
  513. )
  514. # Initialize everything
  515. if sep_pos_embed:
  516. nn.init.trunc_normal_(self.pos_embed_spatial, std=0.02)
  517. nn.init.trunc_normal_(self.pos_embed_temporal, std=0.02)
  518. else:
  519. if self.pos_embed is not None:
  520. nn.init.trunc_normal_(self.pos_embed, std=0.02)
  521. if self.pos_embed_win is not None:
  522. nn.init.trunc_normal_(self.pos_embed_win, std=0.02)
  523. if weight_init != 'skip':
  524. init_fn = init_weight_jax if weight_init == 'jax' else init_weight_vit
  525. init_fn = partial(init_fn, classifier_name='head.fc')
  526. named_apply(init_fn, self)
  527. if fix_init:
  528. self.fix_init_weight()
  529. if isinstance(self.head.fc, nn.Linear):
  530. self.head.fc.weight.data.mul_(head_init_scale)
  531. self.head.fc.bias.data.mul_(head_init_scale)
  532. def fix_init_weight(self):
  533. def rescale(param, _layer_id):
  534. param.div_(math.sqrt(2.0 * _layer_id))
  535. for layer_id, layer in enumerate(self.blocks):
  536. rescale(layer.attn.proj.weight.data, layer_id + 1)
  537. rescale(layer.mlp.fc2.weight.data, layer_id + 1)
  538. @torch.jit.ignore
  539. def no_weight_decay(self):
  540. if self.pos_embed is not None:
  541. return ["pos_embed"]
  542. elif self.pos_embed_abs is not None:
  543. return ['pos_embed_abs', 'pos_embed_win']
  544. else:
  545. return ["pos_embed_spatial", "pos_embed_temporal"]
  546. @torch.jit.ignore
  547. def group_matcher(self, coarse: bool = False) -> Dict:
  548. return dict(
  549. stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|pos_embed_abs|pos_embed_win|patch_embed',
  550. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  551. )
  552. @torch.jit.ignore
  553. def set_grad_checkpointing(self, enable: bool = True) -> None:
  554. self.grad_checkpointing = enable
  555. @torch.jit.ignore
  556. def get_classifier(self):
  557. return self.head.fc
  558. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, reset_other: bool = False):
  559. self.num_classes = num_classes
  560. self.head.reset(num_classes, global_pool, reset_other=reset_other)
  561. def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
  562. """
  563. Generates a random mask, mask_ratio fraction are dropped.
  564. 1 is *keep*, 0 is *remove*. Useful for MAE, FLIP, etc.
  565. """
  566. B = x.shape[0]
  567. # Tokens selected for masking at mask unit level
  568. num_windows = math.prod(self.mask_spatial_shape) # num_mask_units
  569. len_keep = int(num_windows * (1 - mask_ratio))
  570. noise = torch.rand(B, num_windows, device=x.device)
  571. # Sort noise for each sample
  572. ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
  573. ids_restore = torch.argsort(ids_shuffle, dim=1)
  574. # Generate the binary mask: 1 is *keep*, 0 is *remove*
  575. # Note this is opposite to original MAE
  576. mask = torch.zeros([B, num_windows], device=x.device)
  577. mask[:, :len_keep] = 1
  578. # Unshuffle to get the binary mask
  579. mask = torch.gather(mask, dim=1, index=ids_restore)
  580. return mask.bool()
  581. def _pos_embed(self, x) -> torch.Tensor:
  582. if self.pos_embed_win is not None:
  583. # absolute win position embedding, from
  584. # Window Attention is Bugged: How not to Interpolate Position Embeddings (https://arxiv.org/abs/2311.05613)
  585. pos_embed_win = self.pos_embed_win.tile(self.mask_spatial_shape)
  586. pos_embed = F.interpolate(
  587. self.pos_embed,
  588. size=pos_embed_win.shape[-2:],
  589. mode='bicubic',
  590. antialias=True,
  591. )
  592. pos_embed = pos_embed + pos_embed_win
  593. pos_embed = pos_embed.flatten(2).transpose(1, 2)
  594. elif self.pos_embed is not None:
  595. pos_embed = self.pos_embed
  596. else:
  597. pos_embed = (
  598. self.pos_embed_spatial.repeat(1, self.tokens_spatial_shape[0], 1)
  599. +
  600. torch.repeat_interleave(
  601. self.pos_embed_temporal,
  602. self.tokens_spatial_shape[1] * self.tokens_spatial_shape[2],
  603. dim=1,
  604. )
  605. )
  606. x = x + pos_embed
  607. return x
  608. def forward_intermediates(
  609. self,
  610. x: torch.Tensor,
  611. mask: Optional[torch.Tensor] = None,
  612. indices: Optional[Union[int, List[int]]] = None,
  613. norm: bool = False,
  614. stop_early: bool = True,
  615. output_fmt: str = 'NCHW',
  616. intermediates_only: bool = False,
  617. coarse: bool = True,
  618. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  619. """ Forward features that returns intermediates.
  620. Args:
  621. x: Input image tensor
  622. indices: Take last n blocks if int, all if None, select matching indices if sequence
  623. norm: Apply norm layer to all intermediates
  624. stop_early: Stop iterating over blocks when last desired intermediate hit
  625. output_fmt: Shape of intermediate feature outputs
  626. intermediates_only: Only return intermediate features
  627. Returns:
  628. """
  629. assert not norm, 'normalization of features not supported'
  630. assert output_fmt in ('NCHW', 'NHWC'), 'Output format must be one of NCHW, NHWC.'
  631. if coarse:
  632. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  633. take_indices = [self.stage_ends[i] for i in take_indices]
  634. max_index = self.stage_ends[max_index]
  635. else:
  636. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  637. if mask is not None:
  638. patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
  639. else:
  640. patch_mask = None
  641. x = self.patch_embed(x, mask=patch_mask)
  642. x = self._pos_embed(x)
  643. x = self.unroll(x)
  644. # Discard masked tokens
  645. if mask is not None:
  646. x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
  647. intermediates = []
  648. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  649. blocks = self.blocks
  650. else:
  651. blocks = self.blocks[:max_index + 1]
  652. for i, blk in enumerate(blocks):
  653. if self.grad_checkpointing and not torch.jit.is_scripting():
  654. x = checkpoint(blk, x)
  655. else:
  656. x = blk(x)
  657. if i in take_indices:
  658. x_int = self.reroll(x, i, mask=mask)
  659. intermediates.append(x_int.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x_int)
  660. if intermediates_only:
  661. return intermediates
  662. return x, intermediates
  663. def prune_intermediate_layers(
  664. self,
  665. indices: Union[int, List[int]] = 1,
  666. prune_norm: bool = False,
  667. prune_head: bool = True,
  668. coarse: bool = True,
  669. ):
  670. """ Prune layers not required for specified intermediates.
  671. """
  672. if coarse:
  673. take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
  674. max_index = self.stage_ends[max_index]
  675. else:
  676. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  677. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  678. if prune_head:
  679. self.head.reset(0, reset_other=True)
  680. return take_indices
  681. def forward_features(
  682. self,
  683. x: torch.Tensor,
  684. mask: Optional[torch.Tensor] = None,
  685. return_intermediates: bool = False,
  686. ) -> torch.Tensor:
  687. """
  688. mask should be a boolean tensor of shape [B, #MUt*#MUy*#MUx] where #MU are the number of mask units in that dim.
  689. Note: 1 in mask is *keep*, 0 is *remove*; mask.sum(dim=-1) should be the same across the batch.
  690. """
  691. if self.training and self.patch_drop_rate > 0:
  692. # using mask for something like 'patch dropout' via mask-units in supervised train / fine-tune
  693. assert mask is None
  694. mask = self.get_random_mask(x, mask_ratio=self.patch_drop_rate)
  695. if mask is not None:
  696. patch_mask = mask.view(x.shape[0], 1, *self.mask_spatial_shape) # B, C, *mask_spatial_shape
  697. else:
  698. patch_mask = None
  699. x = self.patch_embed(x, mask=patch_mask)
  700. x = self._pos_embed(x)
  701. x = self.unroll(x)
  702. # Discard masked tokens
  703. if mask is not None:
  704. x = x[mask[..., None].tile(1, self.mu_size, x.shape[2])].view(x.shape[0], -1, x.shape[-1])
  705. intermediates = []
  706. for i, blk in enumerate(self.blocks):
  707. if self.grad_checkpointing and not torch.jit.is_scripting():
  708. x = checkpoint(blk, x)
  709. else:
  710. x = blk(x)
  711. if return_intermediates and i in self.stage_ends:
  712. intermediates.append(self.reroll(x, i, mask=mask))
  713. # x may not always be in spatial order here.
  714. # e.g. if q_pool = 2, mask_unit_size = (8, 8), and
  715. # q_stride = (2, 2), not all unrolls were consumed,
  716. # intermediates[-1] is x in spatial order
  717. if return_intermediates:
  718. return x, intermediates
  719. return x
  720. def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
  721. x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  722. return x
  723. def forward(
  724. self,
  725. x: torch.Tensor,
  726. mask: Optional[torch.Tensor] = None,
  727. ) -> torch.Tensor:
  728. x = self.forward_features(x, mask=mask)
  729. if mask is None:
  730. x = self.forward_head(x)
  731. return x
  732. def _cfg(url='', **kwargs):
  733. return {
  734. 'url': url,
  735. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  736. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  737. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  738. 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
  739. 'license': 'apache-2.0',
  740. **kwargs
  741. }
  742. default_cfgs = generate_default_cfgs({
  743. "hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
  744. hf_hub_id='timm/',
  745. license='cc-by-nc-4.0',
  746. ),
  747. "hiera_tiny_224.mae": _cfg(
  748. hf_hub_id='timm/',
  749. license='cc-by-nc-4.0',
  750. num_classes=0,
  751. ),
  752. "hiera_small_224.mae_in1k_ft_in1k": _cfg(
  753. hf_hub_id='timm/',
  754. license='cc-by-nc-4.0',
  755. ),
  756. "hiera_small_224.mae": _cfg(
  757. hf_hub_id='timm/',
  758. license='cc-by-nc-4.0',
  759. num_classes=0,
  760. ),
  761. "hiera_base_224.mae_in1k_ft_in1k": _cfg(
  762. hf_hub_id='timm/',
  763. license='cc-by-nc-4.0',
  764. ),
  765. "hiera_base_224.mae": _cfg(
  766. hf_hub_id='timm/',
  767. license='cc-by-nc-4.0',
  768. num_classes=0,
  769. ),
  770. "hiera_base_plus_224.mae_in1k_ft_in1k": _cfg(
  771. hf_hub_id='timm/',
  772. license='cc-by-nc-4.0',
  773. ),
  774. "hiera_base_plus_224.mae": _cfg(
  775. hf_hub_id='timm/',
  776. license='cc-by-nc-4.0',
  777. num_classes=0,
  778. ),
  779. "hiera_large_224.mae_in1k_ft_in1k": _cfg(
  780. hf_hub_id='timm/',
  781. license='cc-by-nc-4.0',
  782. ),
  783. "hiera_large_224.mae": _cfg(
  784. hf_hub_id='timm/',
  785. license='cc-by-nc-4.0',
  786. num_classes=0,
  787. ),
  788. "hiera_huge_224.mae_in1k_ft_in1k": _cfg(
  789. hf_hub_id='timm/',
  790. license='cc-by-nc-4.0',
  791. ),
  792. "hiera_huge_224.mae": _cfg(
  793. hf_hub_id='timm/',
  794. license='cc-by-nc-4.0',
  795. num_classes=0,
  796. ),
  797. "hiera_small_abswin_256.sbb2_e200_in12k_ft_in1k": _cfg(
  798. hf_hub_id='timm/',
  799. input_size=(3, 256, 256), crop_pct=0.95,
  800. ),
  801. "hiera_small_abswin_256.sbb2_pd_e200_in12k_ft_in1k": _cfg(
  802. hf_hub_id='timm/',
  803. input_size=(3, 256, 256), crop_pct=0.95,
  804. ),
  805. "hiera_small_abswin_256.sbb2_e200_in12k": _cfg(
  806. hf_hub_id='timm/',
  807. num_classes=11821,
  808. input_size=(3, 256, 256), crop_pct=0.95,
  809. ),
  810. "hiera_small_abswin_256.sbb2_pd_e200_in12k": _cfg(
  811. hf_hub_id='timm/',
  812. num_classes=11821,
  813. input_size=(3, 256, 256), crop_pct=0.95,
  814. ),
  815. "hiera_base_abswin_256.untrained": _cfg(
  816. # hf_hub_id='timm/',
  817. input_size=(3, 256, 256), crop_pct=0.95,
  818. ),
  819. })
  820. def checkpoint_filter_fn(state_dict, model=None):
  821. state_dict = state_dict.get('model_state', state_dict)
  822. output = {}
  823. for k, v in state_dict.items():
  824. # if k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
  825. # # To resize pos embedding when using model at different size from pretrained weights
  826. # from timm.layers import resample_abs_pos_embed
  827. # v = resample_abs_pos_embed(
  828. # v,
  829. # new_size=(64, 64),
  830. # num_prefix_tokens=0,
  831. # verbose=True,
  832. # )
  833. if 'head.projection.' in k:
  834. k = k.replace('head.projection.', 'head.fc.')
  835. if k.startswith('encoder_norm.'):
  836. k = k.replace('encoder_norm.', 'head.norm.')
  837. elif k.startswith('norm.'):
  838. k = k.replace('norm.', 'head.norm.')
  839. if k == 'pos_embed_abs':
  840. k = 'pos_embed'
  841. output[k] = v
  842. return output
  843. def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
  844. out_indices = kwargs.pop('out_indices', 4)
  845. return build_model_with_cfg(
  846. Hiera,
  847. variant,
  848. pretrained,
  849. pretrained_filter_fn=checkpoint_filter_fn,
  850. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  851. **kwargs,
  852. )
  853. @register_model
  854. def hiera_tiny_224(pretrained=False, **kwargs):
  855. model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
  856. return _create_hiera('hiera_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))
  857. @register_model
  858. def hiera_small_224(pretrained=False, **kwargs):
  859. model_args = dict(embed_dim=96, num_heads=1, stages=(1, 2, 11, 2))
  860. return _create_hiera('hiera_small_224', pretrained=pretrained, **dict(model_args, **kwargs))
  861. @register_model
  862. def hiera_base_224(pretrained=False, **kwargs):
  863. model_args = dict(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
  864. return _create_hiera('hiera_base_224', pretrained=pretrained, **dict(model_args, **kwargs))
  865. @register_model
  866. def hiera_base_plus_224(pretrained=False, **kwargs):
  867. model_args = dict(embed_dim=112, num_heads=2, stages=(2, 3, 16, 3))
  868. return _create_hiera('hiera_base_plus_224', pretrained=pretrained, **dict(model_args, **kwargs))
  869. @register_model
  870. def hiera_large_224(pretrained=False, **kwargs):
  871. model_args = dict(embed_dim=144, num_heads=2, stages=(2, 6, 36, 4))
  872. return _create_hiera('hiera_large_224', pretrained=pretrained, **dict(model_args, **kwargs))
  873. @register_model
  874. def hiera_huge_224(pretrained=False, **kwargs):
  875. model_args = dict(embed_dim=256, num_heads=4, stages=(2, 6, 36, 4))
  876. return _create_hiera('hiera_huge_224', pretrained=pretrained, **dict(model_args, **kwargs))
  877. @register_model
  878. def hiera_small_abswin_256(pretrained=False, **kwargs):
  879. model_args = dict(
  880. embed_dim=96, num_heads=1, stages=(1, 2, 11, 2), abs_win_pos_embed=True, global_pos_size=(16, 16),
  881. init_values=1e-5, weight_init='jax', use_expand_proj=False,
  882. )
  883. return _create_hiera('hiera_small_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))
  884. @register_model
  885. def hiera_base_abswin_256(pretrained=False, **kwargs):
  886. model_args = dict(
  887. embed_dim=96, num_heads=1, stages=(2, 3, 16, 3), abs_win_pos_embed=True, init_values=1e-5, weight_init='jax')
  888. return _create_hiera('hiera_base_abswin_256', pretrained=pretrained, **dict(model_args, **kwargs))