beit.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991
  1. """ BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
  2. Model from official source: https://github.com/microsoft/unilm/tree/master/beit
  3. @inproceedings{beit,
  4. title={{BEiT}: {BERT} Pre-Training of Image Transformers},
  5. author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
  6. booktitle={International Conference on Learning Representations},
  7. year={2022},
  8. url={https://openreview.net/forum?id=p-BhZSz59o4}
  9. }
  10. BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2
  11. @article{beitv2,
  12. title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
  13. author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
  14. year={2022},
  15. eprint={2208.06366},
  16. archivePrefix={arXiv},
  17. primaryClass={cs.CV}
  18. }
  19. At this point only the fine-tuned classification weights (ImageNet-1k and ImageNet-22k heads) and model configs have been added,
  20. see original source above for pre-training models and procedure.
  21. Modifications by / Copyright 2021 Ross Wightman, original copyrights below
  22. """
  23. # --------------------------------------------------------
  24. # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
  25. # Github source: https://github.com/microsoft/unilm/tree/master/beit
  26. # Copyright (c) 2021 Microsoft
  27. # Licensed under The MIT License [see LICENSE for details]
  28. # By Hangbo Bao
  29. # Based on timm and DeiT code bases
  30. # https://github.com/rwightman/pytorch-image-models/tree/master/timm
  31. # https://github.com/facebookresearch/deit/
  32. # https://github.com/facebookresearch/dino
  33. # --------------------------------------------------------
  34. import math
  35. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
  36. import torch
  37. import torch.nn as nn
  38. import torch.nn.functional as F
  39. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  40. from timm.layers import (
  41. PatchEmbed,
  42. Mlp,
  43. SwiGLU,
  44. LayerNorm,
  45. DropPath,
  46. calculate_drop_path_rates,
  47. trunc_normal_,
  48. use_fused_attn,
  49. resample_patch_embed,
  50. resample_abs_pos_embed,
  51. resize_rel_pos_bias_table,
  52. ndgrid,
  53. )
  54. from ._builder import build_model_with_cfg
  55. from ._features import feature_take_indices
  56. from ._manipulate import checkpoint
  57. from ._registry import generate_default_cfgs, register_model
  58. __all__ = ['Beit']
  59. def gen_relative_position_index(window_size: Tuple[int, int], device=None) -> torch.Tensor:
  60. """Generate relative position index for window-based attention.
  61. Creates a lookup table for relative position indices between all pairs of positions
  62. within a window, including special handling for cls token interactions.
  63. Args:
  64. window_size: Height and width of the attention window.
  65. Returns:
  66. Relative position index tensor of shape (window_area+1, window_area+1)
  67. where +1 accounts for the cls token.
  68. """
  69. num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  70. # cls to token & token 2 cls & cls to cls
  71. # get pair-wise relative position index for each token inside the window
  72. window_area = window_size[0] * window_size[1]
  73. coords = torch.stack(ndgrid(
  74. torch.arange(window_size[0], device=device, dtype=torch.long),
  75. torch.arange(window_size[1], device=device, dtype=torch.long),
  76. )) # 2, Wh, Ww
  77. coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
  78. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
  79. relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
  80. relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
  81. relative_coords[:, :, 1] += window_size[1] - 1
  82. relative_coords[:, :, 0] *= 2 * window_size[1] - 1
  83. relative_position_index = torch.zeros(size=(window_area + 1,) * 2, device=device, dtype=relative_coords.dtype)
  84. relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
  85. relative_position_index[0, 0:] = num_relative_distance - 3
  86. relative_position_index[0:, 0] = num_relative_distance - 2
  87. relative_position_index[0, 0] = num_relative_distance - 1
  88. return relative_position_index
  89. class Attention(nn.Module):
  90. """Multi-head attention module with optional relative position bias.
  91. Implements multi-head self-attention with support for relative position bias
  92. and fused attention operations. Can use either standard or custom head dimensions.
  93. """
  94. fused_attn: torch.jit.Final[bool]
  95. def __init__(
  96. self,
  97. dim: int,
  98. num_heads: int = 8,
  99. qkv_bias: bool = False,
  100. qkv_bias_separate: bool = False,
  101. attn_drop: float = 0.,
  102. proj_drop: float = 0.,
  103. window_size: Optional[Tuple[int, int]] = None,
  104. attn_head_dim: Optional[int] = None,
  105. device=None,
  106. dtype=None,
  107. ):
  108. """Initialize attention module.
  109. Args:
  110. dim: Input feature dimension.
  111. num_heads: Number of attention heads.
  112. qkv_bias: If True, add learnable bias to query, key, value projections.
  113. qkv_bias_separate: If True, use separate bias for q, k, v projections.
  114. attn_drop: Dropout rate for attention weights.
  115. proj_drop: Dropout rate for output projection.
  116. window_size: Window size for relative position bias. If None, no relative position bias.
  117. attn_head_dim: Dimension per attention head. If None, uses dim // num_heads.
  118. """
  119. dd = {'device': device, 'dtype': dtype}
  120. super().__init__()
  121. self.num_heads = num_heads
  122. head_dim = dim // num_heads
  123. if attn_head_dim is not None:
  124. head_dim = attn_head_dim
  125. all_head_dim = head_dim * self.num_heads
  126. self.scale = head_dim ** -0.5
  127. self.fused_attn = use_fused_attn()
  128. self.qkv_bias_separate = qkv_bias_separate
  129. self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False, **dd)
  130. if qkv_bias:
  131. self.q_bias = nn.Parameter(torch.zeros(all_head_dim, **dd))
  132. self.register_buffer('k_bias', torch.zeros(all_head_dim, **dd), persistent=False)
  133. self.v_bias = nn.Parameter(torch.zeros(all_head_dim, **dd))
  134. else:
  135. self.q_bias = None
  136. self.k_bias = None
  137. self.v_bias = None
  138. if window_size:
  139. self.window_size = window_size
  140. self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  141. self.relative_position_bias_table = nn.Parameter(
  142. torch.zeros(self.num_relative_distance, num_heads, **dd)) # 2*Wh-1 * 2*Ww-1, nH
  143. self.register_buffer(
  144. "relative_position_index",
  145. gen_relative_position_index(window_size, device=device),
  146. persistent=False,
  147. )
  148. else:
  149. self.window_size = None
  150. self.relative_position_bias_table = None
  151. self.relative_position_index = None
  152. self.attn_drop = nn.Dropout(attn_drop)
  153. self.proj = nn.Linear(all_head_dim, dim, **dd)
  154. self.proj_drop = nn.Dropout(proj_drop)
  155. def _get_rel_pos_bias(self) -> torch.Tensor:
  156. """Get relative position bias for the attention window.
  157. Returns:
  158. Relative position bias tensor of shape (1, num_heads, window_area+1, window_area+1).
  159. """
  160. relative_position_bias = self.relative_position_bias_table[
  161. self.relative_position_index.view(-1)].view(
  162. self.window_size[0] * self.window_size[1] + 1,
  163. self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
  164. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  165. return relative_position_bias.unsqueeze(0)
  166. def forward(self, x: torch.Tensor, shared_rel_pos_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  167. """Forward pass of attention module.
  168. Args:
  169. x: Input tensor of shape (batch_size, num_tokens, dim).
  170. shared_rel_pos_bias: Optional shared relative position bias from parent module.
  171. Returns:
  172. Output tensor of shape (batch_size, num_tokens, dim).
  173. """
  174. B, N, C = x.shape
  175. if self.q_bias is None:
  176. qkv = self.qkv(x)
  177. else:
  178. qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
  179. if self.qkv_bias_separate:
  180. qkv = self.qkv(x)
  181. qkv += qkv_bias
  182. else:
  183. qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
  184. qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
  185. q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
  186. if self.fused_attn:
  187. rel_pos_bias = None
  188. if self.relative_position_bias_table is not None:
  189. rel_pos_bias = self._get_rel_pos_bias()
  190. if shared_rel_pos_bias is not None:
  191. rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
  192. elif shared_rel_pos_bias is not None:
  193. rel_pos_bias = shared_rel_pos_bias
  194. x = F.scaled_dot_product_attention(
  195. q, k, v,
  196. attn_mask=rel_pos_bias,
  197. dropout_p=self.attn_drop.p if self.training else 0.,
  198. )
  199. else:
  200. q = q * self.scale
  201. attn = (q @ k.transpose(-2, -1))
  202. if self.relative_position_bias_table is not None:
  203. attn = attn + self._get_rel_pos_bias()
  204. if shared_rel_pos_bias is not None:
  205. attn = attn + shared_rel_pos_bias
  206. attn = attn.softmax(dim=-1)
  207. attn = self.attn_drop(attn)
  208. x = attn @ v
  209. x = x.transpose(1, 2).reshape(B, N, C)
  210. x = self.proj(x)
  211. x = self.proj_drop(x)
  212. return x
  213. class Block(nn.Module):
  214. """Transformer block with attention and MLP.
  215. Standard transformer block consisting of multi-head self-attention and MLP
  216. with residual connections and layer normalization. Supports layer scale and
  217. stochastic depth regularization.
  218. """
  219. def __init__(
  220. self,
  221. dim: int,
  222. num_heads: int,
  223. qkv_bias: bool = False,
  224. mlp_ratio: float = 4.,
  225. scale_mlp: bool = False,
  226. swiglu_mlp: bool = False,
  227. proj_drop: float = 0.,
  228. attn_drop: float = 0.,
  229. drop_path: float = 0.,
  230. init_values: Optional[float] = None,
  231. act_layer: Type[nn.Module] = nn.GELU,
  232. norm_layer: Type[nn.Module] = LayerNorm,
  233. window_size: Optional[Tuple[int, int]] = None,
  234. attn_head_dim: Optional[int] = None,
  235. device=None,
  236. dtype=None,
  237. ):
  238. """Initialize transformer block.
  239. Args:
  240. dim: Input feature dimension.
  241. num_heads: Number of attention heads.
  242. qkv_bias: If True, add learnable bias to query, key, value projections.
  243. mlp_ratio: Ratio of MLP hidden dimension to input dimension.
  244. scale_mlp: If True, apply layer normalization in MLP.
  245. swiglu_mlp: If True, use SwiGLU activation in MLP.
  246. proj_drop: Dropout rate for projections.
  247. attn_drop: Dropout rate for attention.
  248. drop_path: Drop path rate for stochastic depth.
  249. init_values: Initial values for layer scale. If None, no layer scale.
  250. act_layer: Activation function class.
  251. norm_layer: Normalization layer class.
  252. window_size: Window size for relative position bias in attention.
  253. attn_head_dim: Dimension per attention head.
  254. """
  255. dd = {'device': device, 'dtype': dtype}
  256. super().__init__()
  257. self.norm1 = norm_layer(dim, **dd)
  258. self.attn = Attention(
  259. dim,
  260. num_heads=num_heads,
  261. qkv_bias=qkv_bias,
  262. attn_drop=attn_drop,
  263. proj_drop=proj_drop,
  264. window_size=window_size,
  265. attn_head_dim=attn_head_dim,
  266. **dd,
  267. )
  268. # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
  269. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  270. self.norm2 = norm_layer(dim, **dd)
  271. if swiglu_mlp:
  272. self.mlp = SwiGLU(
  273. in_features=dim,
  274. hidden_features=int(dim * mlp_ratio),
  275. norm_layer=norm_layer if scale_mlp else None,
  276. drop=proj_drop,
  277. **dd,
  278. )
  279. else:
  280. self.mlp = Mlp(
  281. in_features=dim,
  282. hidden_features=int(dim * mlp_ratio),
  283. act_layer=act_layer,
  284. norm_layer=norm_layer if scale_mlp else None,
  285. drop=proj_drop,
  286. **dd,
  287. )
  288. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  289. if init_values:
  290. self.gamma_1 = nn.Parameter(init_values * torch.ones(dim, **dd))
  291. self.gamma_2 = nn.Parameter(init_values * torch.ones(dim, **dd))
  292. else:
  293. self.gamma_1, self.gamma_2 = None, None
  294. def forward(self, x: torch.Tensor, shared_rel_pos_bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  295. """Forward pass of transformer block.
  296. Args:
  297. x: Input tensor of shape (batch_size, num_tokens, dim).
  298. shared_rel_pos_bias: Optional shared relative position bias.
  299. Returns:
  300. Output tensor of shape (batch_size, num_tokens, dim).
  301. """
  302. if self.gamma_1 is None:
  303. x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
  304. x = x + self.drop_path2(self.mlp(self.norm2(x)))
  305. else:
  306. x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
  307. x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
  308. return x
  309. class RelativePositionBias(nn.Module):
  310. """Relative position bias module for window-based attention.
  311. Generates learnable relative position biases for all pairs of positions
  312. within a window, including special handling for cls token.
  313. """
  314. def __init__(self, window_size: Tuple[int, int], num_heads: int, device=None, dtype=None):
  315. """Initialize relative position bias module.
  316. Args:
  317. window_size: Height and width of the attention window.
  318. num_heads: Number of attention heads.
  319. """
  320. dd = {'device': device, 'dtype': dtype}
  321. super().__init__()
  322. self.window_size = window_size
  323. self.window_area = window_size[0] * window_size[1]
  324. num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
  325. self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads, **dd))
  326. # trunc_normal_(self.relative_position_bias_table, std=.02)
  327. self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
  328. def forward(self) -> torch.Tensor:
  329. """Generate relative position bias.
  330. Returns:
  331. Relative position bias tensor of shape (num_heads, window_area+1, window_area+1).
  332. """
  333. relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
  334. self.window_area + 1, self.window_area + 1, -1) # Wh*Ww,Wh*Ww,nH
  335. return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
  336. class Beit(nn.Module):
  337. """BEiT: BERT Pre-Training of Image Transformers.
  338. Vision Transformer model with support for relative position bias and
  339. shared relative position bias across layers. Implements both BEiT v1 and v2
  340. architectures with flexible configuration options.
  341. """
  342. def __init__(
  343. self,
  344. img_size: Union[int, Tuple[int, int]] = 224,
  345. patch_size: Union[int, Tuple[int, int]] = 16,
  346. in_chans: int = 3,
  347. num_classes: int = 1000,
  348. global_pool: str = 'avg',
  349. embed_dim: int = 768,
  350. depth: int = 12,
  351. num_heads: int = 12,
  352. qkv_bias: bool = True,
  353. mlp_ratio: float = 4.,
  354. swiglu_mlp: bool = False,
  355. scale_mlp: bool = False,
  356. drop_rate: float = 0.,
  357. pos_drop_rate: float = 0.,
  358. proj_drop_rate: float = 0.,
  359. attn_drop_rate: float = 0.,
  360. drop_path_rate: float = 0.,
  361. norm_layer: Type[nn.Module] = LayerNorm,
  362. init_values: Optional[float] = None,
  363. use_abs_pos_emb: bool = True,
  364. use_rel_pos_bias: bool = False,
  365. use_shared_rel_pos_bias: bool = False,
  366. head_init_scale: float = 0.001,
  367. device=None,
  368. dtype=None,
  369. ):
  370. """Initialize BEiT model.
  371. Args:
  372. img_size: Input image size.
  373. patch_size: Patch size for patch embedding.
  374. in_chans: Number of input image channels.
  375. num_classes: Number of classes for classification head.
  376. global_pool: Type of global pooling ('avg' or '').
  377. embed_dim: Embedding dimension.
  378. depth: Number of transformer blocks.
  379. num_heads: Number of attention heads.
  380. qkv_bias: If True, add learnable bias to query, key, value projections.
  381. mlp_ratio: Ratio of MLP hidden dimension to embedding dimension.
  382. swiglu_mlp: If True, use SwiGLU activation in MLP.
  383. scale_mlp: If True, apply layer normalization in MLP.
  384. drop_rate: Dropout rate.
  385. pos_drop_rate: Dropout rate for position embeddings.
  386. proj_drop_rate: Dropout rate for projections.
  387. attn_drop_rate: Dropout rate for attention.
  388. drop_path_rate: Stochastic depth rate.
  389. norm_layer: Normalization layer class.
  390. init_values: Initial values for layer scale.
  391. use_abs_pos_emb: If True, use absolute position embeddings.
  392. use_rel_pos_bias: If True, use relative position bias in attention.
  393. use_shared_rel_pos_bias: If True, share relative position bias across layers.
  394. head_init_scale: Scale factor for head initialization.
  395. """
  396. dd = {'device': device, 'dtype': dtype}
  397. super().__init__()
  398. self.num_classes = num_classes
  399. self.global_pool = global_pool
  400. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
  401. self.num_prefix_tokens = 1
  402. self.grad_checkpointing = False
  403. self.patch_embed = PatchEmbed(
  404. img_size=img_size,
  405. patch_size=patch_size,
  406. in_chans=in_chans,
  407. embed_dim=embed_dim,
  408. **dd,
  409. )
  410. num_patches = self.patch_embed.num_patches
  411. r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
  412. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim, **dd))
  413. # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
  414. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim, **dd)) if use_abs_pos_emb else None
  415. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  416. if use_shared_rel_pos_bias:
  417. self.rel_pos_bias = RelativePositionBias(
  418. window_size=self.patch_embed.grid_size,
  419. num_heads=num_heads,
  420. **dd,
  421. )
  422. else:
  423. self.rel_pos_bias = None
  424. dpr = calculate_drop_path_rates(drop_path_rate, depth) # stochastic depth decay rule
  425. self.blocks = nn.ModuleList([
  426. Block(
  427. dim=embed_dim,
  428. num_heads=num_heads,
  429. qkv_bias=qkv_bias,
  430. mlp_ratio=mlp_ratio,
  431. scale_mlp=scale_mlp,
  432. swiglu_mlp=swiglu_mlp,
  433. proj_drop=proj_drop_rate,
  434. attn_drop=attn_drop_rate,
  435. drop_path=dpr[i],
  436. norm_layer=norm_layer,
  437. init_values=init_values,
  438. window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
  439. **dd,
  440. )
  441. for i in range(depth)])
  442. self.feature_info = [
  443. dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
  444. use_fc_norm = self.global_pool == 'avg'
  445. self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim, **dd)
  446. self.fc_norm = norm_layer(embed_dim, **dd) if use_fc_norm else nn.Identity()
  447. self.head_drop = nn.Dropout(drop_rate)
  448. self.head = nn.Linear(embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  449. self.apply(self._init_weights)
  450. if self.pos_embed is not None:
  451. trunc_normal_(self.pos_embed, std=.02)
  452. trunc_normal_(self.cls_token, std=.02)
  453. self.fix_init_weight()
  454. if isinstance(self.head, nn.Linear):
  455. trunc_normal_(self.head.weight, std=.02)
  456. self.head.weight.data.mul_(head_init_scale)
  457. self.head.bias.data.mul_(head_init_scale)
  458. def fix_init_weight(self):
  459. """Fix initialization weights according to BEiT paper.
  460. Rescales attention and MLP weights based on layer depth to improve
  461. training stability.
  462. """
  463. def rescale(param, layer_id):
  464. param.div_(math.sqrt(2.0 * layer_id))
  465. for layer_id, layer in enumerate(self.blocks):
  466. rescale(layer.attn.proj.weight.data, layer_id + 1)
  467. rescale(layer.mlp.fc2.weight.data, layer_id + 1)
  468. def _init_weights(self, m: nn.Module):
  469. """Initialize model weights.
  470. Args:
  471. m: Module to initialize.
  472. """
  473. if isinstance(m, nn.Linear):
  474. trunc_normal_(m.weight, std=.02)
  475. if isinstance(m, nn.Linear) and m.bias is not None:
  476. nn.init.constant_(m.bias, 0)
  477. elif isinstance(m, nn.LayerNorm):
  478. nn.init.constant_(m.bias, 0)
  479. nn.init.constant_(m.weight, 1.0)
  480. @torch.jit.ignore
  481. def no_weight_decay(self) -> Set[str]:
  482. """Get parameter names that should not use weight decay.
  483. Returns:
  484. Set of parameter names to exclude from weight decay.
  485. """
  486. nwd = {'pos_embed', 'cls_token'}
  487. for n, _ in self.named_parameters():
  488. if 'relative_position_bias_table' in n:
  489. nwd.add(n)
  490. return nwd
  491. @torch.jit.ignore
  492. def set_grad_checkpointing(self, enable: bool = True):
  493. """Enable or disable gradient checkpointing.
  494. Args:
  495. enable: If True, enable gradient checkpointing.
  496. """
  497. self.grad_checkpointing = enable
  498. @torch.jit.ignore
  499. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  500. """Create parameter group matcher for optimizer parameter groups.
  501. Args:
  502. coarse: If True, use coarse grouping.
  503. Returns:
  504. Dictionary mapping group names to regex patterns.
  505. """
  506. matcher = dict(
  507. stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias', # stem and embed
  508. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
  509. )
  510. return matcher
  511. @torch.jit.ignore
  512. def get_classifier(self) -> nn.Module:
  513. """Get the classifier head.
  514. Returns:
  515. The classification head module.
  516. """
  517. return self.head
  518. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  519. """Reset the classification head.
  520. Args:
  521. num_classes: Number of classes for new head.
  522. global_pool: Global pooling type.
  523. """
  524. self.num_classes = num_classes
  525. if global_pool is not None:
  526. self.global_pool = global_pool
  527. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  528. def forward_intermediates(
  529. self,
  530. x: torch.Tensor,
  531. indices: Optional[Union[int, List[int]]] = None,
  532. return_prefix_tokens: bool = False,
  533. norm: bool = False,
  534. stop_early: bool = False,
  535. output_fmt: str = 'NCHW',
  536. intermediates_only: bool = False,
  537. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  538. """Forward pass that returns intermediate feature maps.
  539. Args:
  540. x: Input image tensor of shape (batch_size, channels, height, width).
  541. indices: Block indices to return features from. If int, returns last n blocks.
  542. return_prefix_tokens: If True, return both prefix and spatial tokens.
  543. norm: If True, apply normalization to intermediate features.
  544. stop_early: If True, stop at last selected intermediate.
  545. output_fmt: Output format ('NCHW' or 'NLC').
  546. intermediates_only: If True, only return intermediate features.
  547. Returns:
  548. If intermediates_only is True, returns list of intermediate tensors.
  549. Otherwise, returns tuple of (final_features, intermediates).
  550. """
  551. assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
  552. reshape = output_fmt == 'NCHW'
  553. intermediates = []
  554. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  555. # forward pass
  556. B, _, height, width = x.shape
  557. x = self.patch_embed(x)
  558. x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
  559. if self.pos_embed is not None:
  560. x = x + self.pos_embed
  561. x = self.pos_drop(x)
  562. rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
  563. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  564. blocks = self.blocks
  565. else:
  566. blocks = self.blocks[:max_index + 1]
  567. for i, blk in enumerate(blocks):
  568. if self.grad_checkpointing and not torch.jit.is_scripting():
  569. x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
  570. else:
  571. x = blk(x, shared_rel_pos_bias=rel_pos_bias)
  572. if i in take_indices:
  573. # normalize intermediates with final norm layer if enabled
  574. intermediates.append(self.norm(x) if norm else x)
  575. # process intermediates
  576. if self.num_prefix_tokens:
  577. # split prefix (e.g. class, distill) and spatial feature tokens
  578. prefix_tokens = [y[:, 0:self.num_prefix_tokens] for y in intermediates]
  579. intermediates = [y[:, self.num_prefix_tokens:] for y in intermediates]
  580. if reshape:
  581. # reshape to BCHW output format
  582. H, W = self.patch_embed.dynamic_feat_size((height, width))
  583. intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
  584. if not torch.jit.is_scripting() and return_prefix_tokens:
  585. # return_prefix not support in torchscript due to poor type handling
  586. intermediates = list(zip(intermediates, prefix_tokens))
  587. if intermediates_only:
  588. return intermediates
  589. x = self.norm(x)
  590. return x, intermediates
  591. def prune_intermediate_layers(
  592. self,
  593. indices: Union[int, List[int]] = 1,
  594. prune_norm: bool = False,
  595. prune_head: bool = True,
  596. ) -> List[int]:
  597. """Prune layers not required for specified intermediate outputs.
  598. Args:
  599. indices: Indices of blocks to keep.
  600. prune_norm: If True, remove final normalization.
  601. prune_head: If True, remove classification head.
  602. Returns:
  603. List of indices that were kept.
  604. """
  605. take_indices, max_index = feature_take_indices(len(self.blocks), indices)
  606. self.blocks = self.blocks[:max_index + 1] # truncate blocks
  607. if prune_norm:
  608. self.norm = nn.Identity()
  609. if prune_head:
  610. self.fc_norm = nn.Identity()
  611. self.reset_classifier(0, '')
  612. return take_indices
  613. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  614. """Forward pass through feature extraction layers.
  615. Args:
  616. x: Input tensor of shape (batch_size, channels, height, width).
  617. Returns:
  618. Feature tensor of shape (batch_size, num_tokens, embed_dim).
  619. """
  620. x = self.patch_embed(x)
  621. x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
  622. if self.pos_embed is not None:
  623. x = x + self.pos_embed
  624. x = self.pos_drop(x)
  625. rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
  626. for blk in self.blocks:
  627. if self.grad_checkpointing and not torch.jit.is_scripting():
  628. x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
  629. else:
  630. x = blk(x, shared_rel_pos_bias=rel_pos_bias)
  631. x = self.norm(x)
  632. return x
  633. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  634. """Forward pass through classification head.
  635. Args:
  636. x: Feature tensor of shape (batch_size, num_tokens, embed_dim).
  637. pre_logits: If True, return features before final linear layer.
  638. Returns:
  639. Logits tensor of shape (batch_size, num_classes) or pre-logits.
  640. """
  641. if self.global_pool:
  642. x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
  643. x = self.fc_norm(x)
  644. x = self.head_drop(x)
  645. return x if pre_logits else self.head(x)
  646. def forward(self, x: torch.Tensor) -> torch.Tensor:
  647. """Forward pass through the model.
  648. Args:
  649. x: Input tensor of shape (batch_size, channels, height, width).
  650. Returns:
  651. Logits tensor of shape (batch_size, num_classes).
  652. """
  653. x = self.forward_features(x)
  654. x = self.forward_head(x)
  655. return x
  656. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  657. """Create a default configuration dictionary for BEiT models.
  658. Args:
  659. url: Model weights URL.
  660. **kwargs: Additional configuration parameters.
  661. Returns:
  662. Configuration dictionary.
  663. """
  664. return {
  665. 'url': url,
  666. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  667. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  668. 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
  669. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  670. 'license': 'apache-2.0',
  671. **kwargs
  672. }
  673. default_cfgs = generate_default_cfgs({
  674. 'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
  675. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth',
  676. hf_hub_id='timm/'),
  677. 'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
  678. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
  679. hf_hub_id='timm/',
  680. input_size=(3, 384, 384), crop_pct=1.0,
  681. ),
  682. 'beit_base_patch16_224.in22k_ft_in22k': _cfg(
  683. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
  684. hf_hub_id='timm/',
  685. num_classes=21841,
  686. ),
  687. 'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
  688. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth',
  689. hf_hub_id='timm/'),
  690. 'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
  691. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
  692. hf_hub_id='timm/',
  693. input_size=(3, 384, 384), crop_pct=1.0,
  694. ),
  695. 'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
  696. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
  697. hf_hub_id='timm/',
  698. input_size=(3, 512, 512), crop_pct=1.0,
  699. ),
  700. 'beit_large_patch16_224.in22k_ft_in22k': _cfg(
  701. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
  702. hf_hub_id='timm/',
  703. num_classes=21841,
  704. ),
  705. 'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
  706. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
  707. hf_hub_id='timm/',
  708. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  709. ),
  710. 'beitv2_base_patch16_224.in1k_ft_in1k': _cfg(
  711. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft1k.pth',
  712. hf_hub_id='timm/',
  713. mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  714. ),
  715. 'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
  716. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
  717. hf_hub_id='timm/',
  718. num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  719. ),
  720. 'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
  721. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
  722. hf_hub_id='timm/',
  723. crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  724. ),
  725. 'beitv2_large_patch16_224.in1k_ft_in1k': _cfg(
  726. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft1k.pth',
  727. hf_hub_id='timm/',
  728. crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  729. ),
  730. 'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
  731. #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
  732. hf_hub_id='timm/',
  733. num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
  734. ),
  735. })
  736. def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module, interpolation: str = 'bicubic', antialias: bool = True) -> Dict[str, torch.Tensor]:
  737. """Filter and process checkpoint state dict for loading.
  738. Handles resizing of patch embeddings, position embeddings, and relative position
  739. bias tables when model size differs from checkpoint.
  740. Args:
  741. state_dict: Checkpoint state dictionary.
  742. model: Target model to load weights into.
  743. interpolation: Interpolation method for resizing.
  744. antialias: If True, use antialiasing when resizing.
  745. Returns:
  746. Filtered state dictionary.
  747. """
  748. state_dict = state_dict.get('model', state_dict)
  749. state_dict = state_dict.get('module', state_dict)
  750. # beit v2 didn't strip module
  751. out_dict = {}
  752. for k, v in state_dict.items():
  753. if 'relative_position_index' in k:
  754. continue
  755. if 'patch_embed.proj.weight' in k:
  756. O, I, H, W = model.patch_embed.proj.weight.shape
  757. if v.shape[-1] != W or v.shape[-2] != H:
  758. v = resample_patch_embed(
  759. v,
  760. (H, W),
  761. interpolation=interpolation,
  762. antialias=antialias,
  763. verbose=True,
  764. )
  765. elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
  766. # To resize pos embedding when using model at different size from pretrained weights
  767. num_prefix_tokens = 1
  768. v = resample_abs_pos_embed(
  769. v,
  770. new_size=model.patch_embed.grid_size,
  771. num_prefix_tokens=num_prefix_tokens,
  772. interpolation=interpolation,
  773. antialias=antialias,
  774. verbose=True,
  775. )
  776. elif k.endswith('relative_position_bias_table'):
  777. m = model.get_submodule(k[:-29])
  778. if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
  779. v = resize_rel_pos_bias_table(
  780. v,
  781. new_window_size=m.window_size,
  782. new_bias_shape=m.relative_position_bias_table.shape,
  783. )
  784. out_dict[k] = v
  785. return out_dict
  786. def _create_beit(variant: str, pretrained: bool = False, **kwargs) -> Beit:
  787. """Create a BEiT model.
  788. Args:
  789. variant: Model variant name.
  790. pretrained: If True, load pretrained weights.
  791. **kwargs: Additional model arguments.
  792. Returns:
  793. BEiT model instance.
  794. """
  795. out_indices = kwargs.pop('out_indices', 3)
  796. model = build_model_with_cfg(
  797. Beit, variant, pretrained,
  798. pretrained_filter_fn=checkpoint_filter_fn,
  799. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  800. **kwargs,
  801. )
  802. return model
  803. @register_model
  804. def beit_base_patch16_224(pretrained: bool = False, **kwargs) -> Beit:
  805. """BEiT base model @ 224x224 with patch size 16x16."""
  806. model_args = dict(
  807. patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
  808. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
  809. model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  810. return model
  811. @register_model
  812. def beit_base_patch16_384(pretrained: bool = False, **kwargs) -> Beit:
  813. """BEiT base model @ 384x384 with patch size 16x16."""
  814. model_args = dict(
  815. img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
  816. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
  817. model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  818. return model
  819. @register_model
  820. def beit_large_patch16_224(pretrained: bool = False, **kwargs) -> Beit:
  821. """BEiT large model @ 224x224 with patch size 16x16."""
  822. model_args = dict(
  823. patch_size=16, embed_dim=1024, depth=24, num_heads=16,
  824. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  825. model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  826. return model
  827. @register_model
  828. def beit_large_patch16_384(pretrained: bool = False, **kwargs) -> Beit:
  829. """BEiT large model @ 384x384 with patch size 16x16."""
  830. model_args = dict(
  831. img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
  832. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  833. model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  834. return model
  835. @register_model
  836. def beit_large_patch16_512(pretrained: bool = False, **kwargs) -> Beit:
  837. """BEiT large model @ 512x512 with patch size 16x16."""
  838. model_args = dict(
  839. img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
  840. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  841. model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
  842. return model
  843. @register_model
  844. def beitv2_base_patch16_224(pretrained: bool = False, **kwargs) -> Beit:
  845. """BEiT v2 base model @ 224x224 with patch size 16x16."""
  846. model_args = dict(
  847. patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
  848. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  849. model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  850. return model
  851. @register_model
  852. def beitv2_large_patch16_224(pretrained: bool = False, **kwargs) -> Beit:
  853. """BEiT v2 large model @ 224x224 with patch size 16x16."""
  854. model_args = dict(
  855. patch_size=16, embed_dim=1024, depth=24, num_heads=16,
  856. use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
  857. model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  858. return model