visformer.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590
  1. """ Visformer
  2. Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533
  3. From original at https://github.com/danczs/Visformer
  4. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  5. """
  6. from typing import Optional, Union, Type, Any
  7. import torch
  8. import torch.nn as nn
  9. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  10. from timm.layers import to_2tuple, trunc_normal_, DropPath, calculate_drop_path_rates, PatchEmbed, LayerNorm2d, create_classifier, use_fused_attn
  11. from ._builder import build_model_with_cfg
  12. from ._manipulate import checkpoint_seq
  13. from ._registry import register_model, generate_default_cfgs
  14. __all__ = ['Visformer']
  15. class SpatialMlp(nn.Module):
  16. def __init__(
  17. self,
  18. in_features: int,
  19. hidden_features: Optional[int] = None,
  20. out_features: Optional[int] = None,
  21. act_layer: Type[nn.Module] = nn.GELU,
  22. drop: float = 0.,
  23. group: int = 8,
  24. spatial_conv: bool = False,
  25. device=None,
  26. dtype=None,
  27. ):
  28. dd = {'device': device, 'dtype': dtype}
  29. super().__init__()
  30. out_features = out_features or in_features
  31. hidden_features = hidden_features or in_features
  32. drop_probs = to_2tuple(drop)
  33. self.in_features = in_features
  34. self.out_features = out_features
  35. self.spatial_conv = spatial_conv
  36. if self.spatial_conv:
  37. if group < 2: # net setting
  38. hidden_features = in_features * 5 // 6
  39. else:
  40. hidden_features = in_features * 2
  41. self.hidden_features = hidden_features
  42. self.group = group
  43. self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False, **dd)
  44. self.act1 = act_layer()
  45. self.drop1 = nn.Dropout(drop_probs[0])
  46. if self.spatial_conv:
  47. self.conv2 = nn.Conv2d(
  48. hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False, **dd)
  49. self.act2 = act_layer()
  50. else:
  51. self.conv2 = None
  52. self.act2 = None
  53. self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False, **dd)
  54. self.drop3 = nn.Dropout(drop_probs[1])
  55. def forward(self, x):
  56. x = self.conv1(x)
  57. x = self.act1(x)
  58. x = self.drop1(x)
  59. if self.conv2 is not None:
  60. x = self.conv2(x)
  61. x = self.act2(x)
  62. x = self.conv3(x)
  63. x = self.drop3(x)
  64. return x
  65. class Attention(nn.Module):
  66. fused_attn: torch.jit.Final[bool]
  67. def __init__(
  68. self,
  69. dim: int,
  70. num_heads: int = 8,
  71. head_dim_ratio: float = 1.,
  72. attn_drop: float = 0.,
  73. proj_drop: float = 0.,
  74. device=None,
  75. dtype=None,
  76. ):
  77. dd = {'device': device, 'dtype': dtype}
  78. super().__init__()
  79. self.dim = dim
  80. self.num_heads = num_heads
  81. head_dim = round(dim // num_heads * head_dim_ratio)
  82. self.head_dim = head_dim
  83. self.scale = head_dim ** -0.5
  84. self.fused_attn = use_fused_attn(experimental=True)
  85. self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False, **dd)
  86. self.attn_drop = nn.Dropout(attn_drop)
  87. self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False, **dd)
  88. self.proj_drop = nn.Dropout(proj_drop)
  89. def forward(self, x):
  90. B, C, H, W = x.shape
  91. x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
  92. q, k, v = x.unbind(0)
  93. if self.fused_attn:
  94. x = torch.nn.functional.scaled_dot_product_attention(
  95. q.contiguous(), k.contiguous(), v.contiguous(),
  96. dropout_p=self.attn_drop.p if self.training else 0.,
  97. )
  98. else:
  99. attn = (q @ k.transpose(-2, -1)) * self.scale
  100. attn = attn.softmax(dim=-1)
  101. attn = self.attn_drop(attn)
  102. x = attn @ v
  103. x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
  104. x = self.proj(x)
  105. x = self.proj_drop(x)
  106. return x
  107. class Block(nn.Module):
  108. def __init__(
  109. self,
  110. dim: int,
  111. num_heads: int,
  112. head_dim_ratio: float = 1.,
  113. mlp_ratio: float = 4.,
  114. proj_drop: float = 0.,
  115. attn_drop: float = 0.,
  116. drop_path: float = 0.,
  117. act_layer: Type[nn.Module] = nn.GELU,
  118. norm_layer: Type[nn.Module] = LayerNorm2d,
  119. group: int = 8,
  120. attn_disabled: bool = False,
  121. spatial_conv: bool = False,
  122. device=None,
  123. dtype=None,
  124. ):
  125. dd = {'device': device, 'dtype': dtype}
  126. super().__init__()
  127. self.spatial_conv = spatial_conv
  128. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  129. if attn_disabled:
  130. self.norm1 = None
  131. self.attn = None
  132. else:
  133. self.norm1 = norm_layer(dim, **dd)
  134. self.attn = Attention(
  135. dim,
  136. num_heads=num_heads,
  137. head_dim_ratio=head_dim_ratio,
  138. attn_drop=attn_drop,
  139. proj_drop=proj_drop,
  140. **dd,
  141. )
  142. self.norm2 = norm_layer(dim, **dd)
  143. self.mlp = SpatialMlp(
  144. in_features=dim,
  145. hidden_features=int(dim * mlp_ratio),
  146. act_layer=act_layer,
  147. drop=proj_drop,
  148. group=group,
  149. spatial_conv=spatial_conv,
  150. **dd,
  151. )
  152. def forward(self, x):
  153. if self.attn is not None:
  154. x = x + self.drop_path(self.attn(self.norm1(x)))
  155. x = x + self.drop_path(self.mlp(self.norm2(x)))
  156. return x
  157. class Visformer(nn.Module):
  158. def __init__(
  159. self,
  160. img_size: int = 224,
  161. patch_size: int = 16,
  162. in_chans: int = 3,
  163. num_classes: int = 1000,
  164. init_channels: Optional[int] = 32,
  165. embed_dim: int = 384,
  166. depth: Union[int, tuple] = 12,
  167. num_heads: int = 6,
  168. mlp_ratio: float = 4.,
  169. drop_rate: float = 0.,
  170. pos_drop_rate: float = 0.,
  171. proj_drop_rate: float = 0.,
  172. attn_drop_rate: float = 0.,
  173. drop_path_rate: float = 0.,
  174. norm_layer: Type[nn.Module] = LayerNorm2d,
  175. attn_stage: str = '111',
  176. use_pos_embed: bool = True,
  177. spatial_conv: str = '111',
  178. vit_stem: bool = False,
  179. group: int = 8,
  180. global_pool: str = 'avg',
  181. conv_init: bool = False,
  182. embed_norm: Optional[Type[nn.Module]] = None,
  183. device=None,
  184. dtype=None,
  185. ):
  186. super().__init__()
  187. dd = {'device': device, 'dtype': dtype}
  188. img_size = to_2tuple(img_size)
  189. self.num_classes = num_classes
  190. self.embed_dim = embed_dim
  191. self.init_channels = init_channels
  192. self.img_size = img_size
  193. self.vit_stem = vit_stem
  194. self.conv_init = conv_init
  195. if isinstance(depth, (list, tuple)):
  196. self.stage_num1, self.stage_num2, self.stage_num3 = depth
  197. depth = sum(depth)
  198. else:
  199. self.stage_num1 = self.stage_num3 = depth // 3
  200. self.stage_num2 = depth - self.stage_num1 - self.stage_num3
  201. self.use_pos_embed = use_pos_embed
  202. self.grad_checkpointing = False
  203. dpr = calculate_drop_path_rates(drop_path_rate, depth)
  204. # stage 1
  205. if self.vit_stem:
  206. self.stem = None
  207. self.patch_embed1 = PatchEmbed(
  208. img_size=img_size,
  209. patch_size=patch_size,
  210. in_chans=in_chans,
  211. embed_dim=embed_dim,
  212. norm_layer=embed_norm,
  213. flatten=False,
  214. **dd,
  215. )
  216. img_size = [x // patch_size for x in img_size]
  217. else:
  218. if self.init_channels is None:
  219. self.stem = None
  220. self.patch_embed1 = PatchEmbed(
  221. img_size=img_size,
  222. patch_size=patch_size // 2,
  223. in_chans=in_chans,
  224. embed_dim=embed_dim // 2,
  225. norm_layer=embed_norm,
  226. flatten=False,
  227. **dd,
  228. )
  229. img_size = [x // (patch_size // 2) for x in img_size]
  230. else:
  231. self.stem = nn.Sequential(
  232. nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False, **dd),
  233. nn.BatchNorm2d(self.init_channels, **dd),
  234. nn.ReLU(inplace=True)
  235. )
  236. img_size = [x // 2 for x in img_size]
  237. self.patch_embed1 = PatchEmbed(
  238. img_size=img_size,
  239. patch_size=patch_size // 4,
  240. in_chans=self.init_channels,
  241. embed_dim=embed_dim // 2,
  242. norm_layer=embed_norm,
  243. flatten=False,
  244. **dd,
  245. )
  246. img_size = [x // (patch_size // 4) for x in img_size]
  247. if self.use_pos_embed:
  248. if self.vit_stem:
  249. self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size, **dd))
  250. else:
  251. self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size, **dd))
  252. self.pos_drop = nn.Dropout(p=pos_drop_rate)
  253. else:
  254. self.pos_embed1 = None
  255. self.stage1 = nn.Sequential(*[
  256. Block(
  257. dim=embed_dim//2,
  258. num_heads=num_heads,
  259. head_dim_ratio=0.5,
  260. mlp_ratio=mlp_ratio,
  261. proj_drop=proj_drop_rate,
  262. attn_drop=attn_drop_rate,
  263. drop_path=dpr[i],
  264. norm_layer=norm_layer,
  265. group=group,
  266. attn_disabled=(attn_stage[0] == '0'),
  267. spatial_conv=(spatial_conv[0] == '1'),
  268. **dd,
  269. )
  270. for i in range(self.stage_num1)
  271. ])
  272. # stage2
  273. if not self.vit_stem:
  274. self.patch_embed2 = PatchEmbed(
  275. img_size=img_size,
  276. patch_size=patch_size // 8,
  277. in_chans=embed_dim // 2,
  278. embed_dim=embed_dim,
  279. norm_layer=embed_norm,
  280. flatten=False,
  281. **dd,
  282. )
  283. img_size = [x // (patch_size // 8) for x in img_size]
  284. if self.use_pos_embed:
  285. self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size, **dd))
  286. else:
  287. self.pos_embed2 = None
  288. else:
  289. self.patch_embed2 = None
  290. self.stage2 = nn.Sequential(*[
  291. Block(
  292. dim=embed_dim,
  293. num_heads=num_heads,
  294. head_dim_ratio=1.0,
  295. mlp_ratio=mlp_ratio,
  296. proj_drop=proj_drop_rate,
  297. attn_drop=attn_drop_rate,
  298. drop_path=dpr[i],
  299. norm_layer=norm_layer,
  300. group=group,
  301. attn_disabled=(attn_stage[1] == '0'),
  302. spatial_conv=(spatial_conv[1] == '1'),
  303. **dd,
  304. )
  305. for i in range(self.stage_num1, self.stage_num1+self.stage_num2)
  306. ])
  307. # stage 3
  308. if not self.vit_stem:
  309. self.patch_embed3 = PatchEmbed(
  310. img_size=img_size,
  311. patch_size=patch_size // 8,
  312. in_chans=embed_dim,
  313. embed_dim=embed_dim * 2,
  314. norm_layer=embed_norm,
  315. flatten=False,
  316. **dd,
  317. )
  318. img_size = [x // (patch_size // 8) for x in img_size]
  319. if self.use_pos_embed:
  320. self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size, **dd))
  321. else:
  322. self.pos_embed3 = None
  323. else:
  324. self.patch_embed3 = None
  325. self.stage3 = nn.Sequential(*[
  326. Block(
  327. dim=embed_dim * 2,
  328. num_heads=num_heads,
  329. head_dim_ratio=1.0,
  330. mlp_ratio=mlp_ratio,
  331. proj_drop=proj_drop_rate,
  332. attn_drop=attn_drop_rate,
  333. drop_path=dpr[i],
  334. norm_layer=norm_layer,
  335. group=group,
  336. attn_disabled=(attn_stage[2] == '0'),
  337. spatial_conv=(spatial_conv[2] == '1'),
  338. **dd,
  339. )
  340. for i in range(self.stage_num1+self.stage_num2, depth)
  341. ])
  342. self.num_features = self.head_hidden_size = embed_dim if self.vit_stem else embed_dim * 2
  343. self.norm = norm_layer(self.num_features, **dd)
  344. # head
  345. global_pool, head = create_classifier(
  346. self.num_features,
  347. self.num_classes,
  348. pool_type=global_pool,
  349. device=device,
  350. dtype=dtype,
  351. )
  352. self.global_pool = global_pool
  353. self.head_drop = nn.Dropout(drop_rate)
  354. self.head = head
  355. # weights init
  356. if self.use_pos_embed:
  357. trunc_normal_(self.pos_embed1, std=0.02)
  358. if not self.vit_stem:
  359. trunc_normal_(self.pos_embed2, std=0.02)
  360. trunc_normal_(self.pos_embed3, std=0.02)
  361. self.apply(self._init_weights)
  362. def _init_weights(self, m):
  363. if isinstance(m, nn.Linear):
  364. trunc_normal_(m.weight, std=0.02)
  365. if m.bias is not None:
  366. nn.init.constant_(m.bias, 0)
  367. elif isinstance(m, nn.Conv2d):
  368. if self.conv_init:
  369. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  370. else:
  371. trunc_normal_(m.weight, std=0.02)
  372. if m.bias is not None:
  373. nn.init.constant_(m.bias, 0.)
  374. @torch.jit.ignore
  375. def group_matcher(self, coarse=False):
  376. return dict(
  377. stem=r'^patch_embed1|pos_embed1|stem', # stem and embed
  378. blocks=[
  379. (r'^stage(\d+)\.(\d+)' if coarse else r'^stage(\d+)\.(\d+)', None),
  380. (r'^(?:patch_embed|pos_embed)(\d+)', (0,)),
  381. (r'^norm', (99999,))
  382. ]
  383. )
  384. @torch.jit.ignore
  385. def set_grad_checkpointing(self, enable=True):
  386. self.grad_checkpointing = enable
  387. @torch.jit.ignore
  388. def get_classifier(self) -> nn.Module:
  389. return self.head
  390. def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
  391. self.num_classes = num_classes
  392. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  393. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  394. self.global_pool, self.head = create_classifier(
  395. self.num_features, self.num_classes, pool_type=global_pool, device=device, dtype=dtype)
  396. def forward_features(self, x):
  397. if self.stem is not None:
  398. x = self.stem(x)
  399. # stage 1
  400. x = self.patch_embed1(x)
  401. if self.pos_embed1 is not None:
  402. x = self.pos_drop(x + self.pos_embed1)
  403. if self.grad_checkpointing and not torch.jit.is_scripting():
  404. x = checkpoint_seq(self.stage1, x)
  405. else:
  406. x = self.stage1(x)
  407. # stage 2
  408. if self.patch_embed2 is not None:
  409. x = self.patch_embed2(x)
  410. if self.pos_embed2 is not None:
  411. x = self.pos_drop(x + self.pos_embed2)
  412. if self.grad_checkpointing and not torch.jit.is_scripting():
  413. x = checkpoint_seq(self.stage2, x)
  414. else:
  415. x = self.stage2(x)
  416. # stage3
  417. if self.patch_embed3 is not None:
  418. x = self.patch_embed3(x)
  419. if self.pos_embed3 is not None:
  420. x = self.pos_drop(x + self.pos_embed3)
  421. if self.grad_checkpointing and not torch.jit.is_scripting():
  422. x = checkpoint_seq(self.stage3, x)
  423. else:
  424. x = self.stage3(x)
  425. x = self.norm(x)
  426. return x
  427. def forward_head(self, x, pre_logits: bool = False):
  428. x = self.global_pool(x)
  429. x = self.head_drop(x)
  430. return x if pre_logits else self.head(x)
  431. def forward(self, x):
  432. x = self.forward_features(x)
  433. x = self.forward_head(x)
  434. return x
  435. def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
  436. if kwargs.get('features_only', None):
  437. raise RuntimeError('features_only not implemented for Vision Transformer models.')
  438. model = build_model_with_cfg(Visformer, variant, pretrained, **kwargs)
  439. return model
  440. def _cfg(url='', **kwargs):
  441. return {
  442. 'url': url,
  443. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  444. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  445. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  446. 'first_conv': 'stem.0', 'classifier': 'head',
  447. 'license': 'apache-2.0',
  448. **kwargs
  449. }
  450. default_cfgs = generate_default_cfgs({
  451. 'visformer_tiny.in1k': _cfg(hf_hub_id='timm/'),
  452. 'visformer_small.in1k': _cfg(hf_hub_id='timm/'),
  453. })
  454. @register_model
  455. def visformer_tiny(pretrained=False, **kwargs) -> Visformer:
  456. model_cfg = dict(
  457. init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
  458. attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
  459. embed_norm=nn.BatchNorm2d)
  460. model = _create_visformer('visformer_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
  461. return model
  462. @register_model
  463. def visformer_small(pretrained=False, **kwargs) -> Visformer:
  464. model_cfg = dict(
  465. init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
  466. attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
  467. embed_norm=nn.BatchNorm2d)
  468. model = _create_visformer('visformer_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
  469. return model
  470. # @register_model
  471. # def visformer_net1(pretrained=False, **kwargs):
  472. # model = Visformer(
  473. # init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
  474. # spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
  475. # model.default_cfg = _cfg()
  476. # return model
  477. #
  478. #
  479. # @register_model
  480. # def visformer_net2(pretrained=False, **kwargs):
  481. # model = Visformer(
  482. # init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
  483. # spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
  484. # model.default_cfg = _cfg()
  485. # return model
  486. #
  487. #
  488. # @register_model
  489. # def visformer_net3(pretrained=False, **kwargs):
  490. # model = Visformer(
  491. # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
  492. # spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
  493. # model.default_cfg = _cfg()
  494. # return model
  495. #
  496. #
  497. # @register_model
  498. # def visformer_net4(pretrained=False, **kwargs):
  499. # model = Visformer(
  500. # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
  501. # spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
  502. # model.default_cfg = _cfg()
  503. # return model
  504. #
  505. #
  506. # @register_model
  507. # def visformer_net5(pretrained=False, **kwargs):
  508. # model = Visformer(
  509. # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
  510. # spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
  511. # model.default_cfg = _cfg()
  512. # return model
  513. #
  514. #
  515. # @register_model
  516. # def visformer_net6(pretrained=False, **kwargs):
  517. # model = Visformer(
  518. # init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
  519. # use_pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
  520. # model.default_cfg = _cfg()
  521. # return model
  522. #
  523. #
  524. # @register_model
  525. # def visformer_net7(pretrained=False, **kwargs):
  526. # model = Visformer(
  527. # init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
  528. # use_pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
  529. # model.default_cfg = _cfg()
  530. # return model