pit.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. """ Pooling-based Vision Transformer (PiT) in PyTorch
  2. A PyTorch implementation of Pooling-based Vision Transformers as described in
  3. 'Rethinking Spatial Dimensions of Vision Transformers' - https://arxiv.org/abs/2103.16302
  4. This code was adapted from the original version at https://github.com/naver-ai/pit, original copyright below.
  5. Modifications for timm by / Copyright 2020 Ross Wightman
  6. """
  7. # PiT
  8. # Copyright 2021-present NAVER Corp.
  9. # Apache License v2.0
  10. import math
  11. import re
  12. from functools import partial
  13. from typing import List, Optional, Sequence, Tuple, Union, Type, Any
  14. import torch
  15. from torch import nn
  16. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  17. from timm.layers import trunc_normal_, to_2tuple, calculate_drop_path_rates
  18. from ._builder import build_model_with_cfg
  19. from ._features import feature_take_indices
  20. from ._registry import register_model, generate_default_cfgs
  21. from .vision_transformer import Block
  22. __all__ = ['PoolingVisionTransformer'] # model_registry will add each entrypoint fn to this
  23. class SequentialTuple(nn.Sequential):
  24. """ This module exists to work around torchscript typing issues list -> list"""
  25. def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
  26. for module in self:
  27. x = module(x)
  28. return x
  29. class Transformer(nn.Module):
  30. def __init__(
  31. self,
  32. base_dim: int,
  33. depth: int,
  34. heads: int,
  35. mlp_ratio: float,
  36. pool: Optional[Any] = None,
  37. proj_drop: float = .0,
  38. attn_drop: float = .0,
  39. drop_path_prob: Optional[List[float]] = None,
  40. norm_layer: Optional[Type[nn.Module]] = None,
  41. device=None,
  42. dtype=None,
  43. ):
  44. dd = {'device': device, 'dtype': dtype}
  45. super().__init__()
  46. embed_dim = base_dim * heads
  47. self.pool = pool
  48. self.norm = norm_layer(embed_dim, **dd) if norm_layer else nn.Identity()
  49. self.blocks = nn.Sequential(*[
  50. Block(
  51. dim=embed_dim,
  52. num_heads=heads,
  53. mlp_ratio=mlp_ratio,
  54. qkv_bias=True,
  55. proj_drop=proj_drop,
  56. attn_drop=attn_drop,
  57. drop_path=drop_path_prob[i],
  58. norm_layer=partial(nn.LayerNorm, eps=1e-6),
  59. **dd,
  60. )
  61. for i in range(depth)])
  62. def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
  63. x, cls_tokens = x
  64. token_length = cls_tokens.shape[1]
  65. if self.pool is not None:
  66. x, cls_tokens = self.pool(x, cls_tokens)
  67. B, C, H, W = x.shape
  68. x = x.flatten(2).transpose(1, 2)
  69. x = torch.cat((cls_tokens, x), dim=1)
  70. x = self.norm(x)
  71. x = self.blocks(x)
  72. cls_tokens = x[:, :token_length]
  73. x = x[:, token_length:]
  74. x = x.transpose(1, 2).reshape(B, C, H, W)
  75. return x, cls_tokens
  76. class Pooling(nn.Module):
  77. def __init__(
  78. self,
  79. in_feature: int,
  80. out_feature: int,
  81. stride: int,
  82. padding_mode: str = 'zeros',
  83. device=None,
  84. dtype=None,
  85. ):
  86. dd = {'device': device, 'dtype': dtype}
  87. super().__init__()
  88. self.conv = nn.Conv2d(
  89. in_feature,
  90. out_feature,
  91. kernel_size=stride + 1,
  92. padding=stride // 2,
  93. stride=stride,
  94. padding_mode=padding_mode,
  95. groups=in_feature,
  96. **dd,
  97. )
  98. self.fc = nn.Linear(in_feature, out_feature, **dd)
  99. def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
  100. x = self.conv(x)
  101. cls_token = self.fc(cls_token)
  102. return x, cls_token
  103. class ConvEmbedding(nn.Module):
  104. def __init__(
  105. self,
  106. in_channels: int,
  107. out_channels: int,
  108. img_size: int = 224,
  109. patch_size: int = 16,
  110. stride: int = 8,
  111. padding: int = 0,
  112. device=None,
  113. dtype=None,
  114. ):
  115. dd = {'device': device, 'dtype': dtype}
  116. super().__init__()
  117. padding = padding
  118. self.img_size = to_2tuple(img_size)
  119. self.patch_size = to_2tuple(patch_size)
  120. self.height = math.floor((self.img_size[0] + 2 * padding - self.patch_size[0]) / stride + 1)
  121. self.width = math.floor((self.img_size[1] + 2 * padding - self.patch_size[1]) / stride + 1)
  122. self.grid_size = (self.height, self.width)
  123. self.conv = nn.Conv2d(
  124. in_channels,
  125. out_channels,
  126. kernel_size=patch_size,
  127. stride=stride,
  128. padding=padding,
  129. bias=True,
  130. **dd,
  131. )
  132. def forward(self, x):
  133. x = self.conv(x)
  134. return x
  135. class PoolingVisionTransformer(nn.Module):
  136. """ Pooling-based Vision Transformer
  137. A PyTorch implement of 'Rethinking Spatial Dimensions of Vision Transformers'
  138. - https://arxiv.org/abs/2103.16302
  139. """
  140. def __init__(
  141. self,
  142. img_size: int = 224,
  143. patch_size: int = 16,
  144. stride: int = 8,
  145. stem_type: str = 'overlap',
  146. base_dims: Sequence[int] = (48, 48, 48),
  147. depth: Sequence[int] = (2, 6, 4),
  148. heads: Sequence[int] = (2, 4, 8),
  149. mlp_ratio: float = 4,
  150. num_classes: int = 1000,
  151. in_chans: int = 3,
  152. global_pool: str = 'token',
  153. distilled: bool = False,
  154. drop_rate: float = 0.,
  155. pos_drop_drate: float = 0.,
  156. proj_drop_rate: float = 0.,
  157. attn_drop_rate: float = 0.,
  158. drop_path_rate: float = 0.,
  159. device=None,
  160. dtype=None,
  161. ):
  162. super().__init__()
  163. dd = {'device': device, 'dtype': dtype}
  164. assert global_pool in ('token',)
  165. self.base_dims = base_dims
  166. self.heads = heads
  167. embed_dim = base_dims[0] * heads[0]
  168. self.num_classes = num_classes
  169. self.global_pool = global_pool
  170. self.num_tokens = 2 if distilled else 1
  171. self.feature_info = []
  172. self.patch_embed = ConvEmbedding(in_chans, embed_dim, img_size, patch_size, stride, **dd)
  173. self.pos_embed = nn.Parameter(torch.randn(1, embed_dim, self.patch_embed.height, self.patch_embed.width, **dd))
  174. self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim, **dd))
  175. self.pos_drop = nn.Dropout(p=pos_drop_drate)
  176. transformers = []
  177. # stochastic depth decay rule
  178. dpr = calculate_drop_path_rates(drop_path_rate, depth, stagewise=True)
  179. prev_dim = embed_dim
  180. for i in range(len(depth)):
  181. pool = None
  182. embed_dim = base_dims[i] * heads[i]
  183. if i > 0:
  184. pool = Pooling(
  185. prev_dim,
  186. embed_dim,
  187. stride=2,
  188. **dd,
  189. )
  190. transformers += [Transformer(
  191. base_dims[i],
  192. depth[i],
  193. heads[i],
  194. mlp_ratio,
  195. pool=pool,
  196. proj_drop=proj_drop_rate,
  197. attn_drop=attn_drop_rate,
  198. drop_path_prob=dpr[i],
  199. **dd,
  200. )]
  201. prev_dim = embed_dim
  202. self.feature_info += [dict(num_chs=prev_dim, reduction=(stride - 1) * 2**i, module=f'transformers.{i}')]
  203. self.transformers = SequentialTuple(*transformers)
  204. self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6, **dd)
  205. self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
  206. # Classifier head
  207. self.head_drop = nn.Dropout(drop_rate)
  208. self.head = nn.Linear(self.embed_dim, num_classes, **dd) if num_classes > 0 else nn.Identity()
  209. self.head_dist = None
  210. if distilled:
  211. self.head_dist = nn.Linear(self.embed_dim, self.num_classes, **dd) if num_classes > 0 else nn.Identity()
  212. self.distilled_training = False # must set this True to train w/ distillation token
  213. trunc_normal_(self.pos_embed, std=.02)
  214. trunc_normal_(self.cls_token, std=.02)
  215. self.apply(self._init_weights)
  216. def _init_weights(self, m):
  217. if isinstance(m, nn.LayerNorm):
  218. nn.init.constant_(m.bias, 0)
  219. nn.init.constant_(m.weight, 1.0)
  220. @torch.jit.ignore
  221. def no_weight_decay(self):
  222. return {'pos_embed', 'cls_token'}
  223. @torch.jit.ignore
  224. def set_distilled_training(self, enable=True):
  225. self.distilled_training = enable
  226. @torch.jit.ignore
  227. def set_grad_checkpointing(self, enable=True):
  228. assert not enable, 'gradient checkpointing not supported'
  229. def get_classifier(self) -> nn.Module:
  230. if self.head_dist is not None:
  231. return self.head, self.head_dist
  232. else:
  233. return self.head
  234. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  235. self.num_classes = num_classes
  236. if global_pool is not None:
  237. self.global_pool = global_pool
  238. device = self.head.weight.device if hasattr(self.head, 'weight') else None
  239. dtype = self.head.weight.dtype if hasattr(self.head, 'weight') else None
  240. self.head = nn.Linear(self.embed_dim, num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  241. if self.head_dist is not None:
  242. self.head_dist = nn.Linear(self.embed_dim, self.num_classes, device=device, dtype=dtype) if num_classes > 0 else nn.Identity()
  243. def forward_intermediates(
  244. self,
  245. x: torch.Tensor,
  246. indices: Optional[Union[int, List[int]]] = None,
  247. norm: bool = False,
  248. stop_early: bool = False,
  249. output_fmt: str = 'NCHW',
  250. intermediates_only: bool = False,
  251. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  252. """ Forward features that returns intermediates.
  253. Args:
  254. x: Input image tensor
  255. indices: Take last n blocks if int, all if None, select matching indices if sequence
  256. norm: Apply norm layer to compatible intermediates
  257. stop_early: Stop iterating over blocks when last desired intermediate hit
  258. output_fmt: Shape of intermediate feature outputs
  259. intermediates_only: Only return intermediate features
  260. Returns:
  261. """
  262. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  263. intermediates = []
  264. take_indices, max_index = feature_take_indices(len(self.transformers), indices)
  265. # forward pass
  266. x = self.patch_embed(x)
  267. x = self.pos_drop(x + self.pos_embed)
  268. cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
  269. last_idx = len(self.transformers) - 1
  270. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  271. stages = self.transformers
  272. else:
  273. stages = self.transformers[:max_index + 1]
  274. for feat_idx, stage in enumerate(stages):
  275. x, cls_tokens = stage((x, cls_tokens))
  276. if feat_idx in take_indices:
  277. intermediates.append(x)
  278. if intermediates_only:
  279. return intermediates
  280. if feat_idx == last_idx:
  281. cls_tokens = self.norm(cls_tokens)
  282. return cls_tokens, intermediates
  283. def prune_intermediate_layers(
  284. self,
  285. indices: Union[int, List[int]] = 1,
  286. prune_norm: bool = False,
  287. prune_head: bool = True,
  288. ):
  289. """ Prune layers not required for specified intermediates.
  290. """
  291. take_indices, max_index = feature_take_indices(len(self.transformers), indices)
  292. self.transformers = self.transformers[:max_index + 1] # truncate blocks w/ stem as idx 0
  293. if prune_norm:
  294. self.norm = nn.Identity()
  295. if prune_head:
  296. self.reset_classifier(0, '')
  297. return take_indices
  298. def forward_features(self, x):
  299. x = self.patch_embed(x)
  300. x = self.pos_drop(x + self.pos_embed)
  301. cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
  302. x, cls_tokens = self.transformers((x, cls_tokens))
  303. cls_tokens = self.norm(cls_tokens)
  304. return cls_tokens
  305. def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
  306. if self.head_dist is not None:
  307. assert self.global_pool == 'token'
  308. x, x_dist = x[:, 0], x[:, 1]
  309. x = self.head_drop(x)
  310. x_dist = self.head_drop(x)
  311. if not pre_logits:
  312. x = self.head(x)
  313. x_dist = self.head_dist(x_dist)
  314. if self.distilled_training and self.training and not torch.jit.is_scripting():
  315. # only return separate classification predictions when training in distilled mode
  316. return x, x_dist
  317. else:
  318. # during standard train / finetune, inference average the classifier predictions
  319. return (x + x_dist) / 2
  320. else:
  321. if self.global_pool == 'token':
  322. x = x[:, 0]
  323. x = self.head_drop(x)
  324. if not pre_logits:
  325. x = self.head(x)
  326. return x
  327. def forward(self, x):
  328. x = self.forward_features(x)
  329. x = self.forward_head(x)
  330. return x
  331. def checkpoint_filter_fn(state_dict, model):
  332. """ preprocess checkpoints """
  333. out_dict = {}
  334. p_blocks = re.compile(r'pools\.(\d)\.')
  335. for k, v in state_dict.items():
  336. # FIXME need to update resize for PiT impl
  337. # if k == 'pos_embed' and v.shape != model.pos_embed.shape:
  338. # # To resize pos embedding when using model at different size from pretrained weights
  339. # v = resize_pos_embed(v, model.pos_embed)
  340. k = p_blocks.sub(lambda exp: f'transformers.{int(exp.group(1)) + 1}.pool.', k)
  341. out_dict[k] = v
  342. return out_dict
  343. def _create_pit(variant, pretrained=False, **kwargs):
  344. default_out_indices = tuple(range(3))
  345. out_indices = kwargs.pop('out_indices', default_out_indices)
  346. model = build_model_with_cfg(
  347. PoolingVisionTransformer,
  348. variant,
  349. pretrained,
  350. pretrained_filter_fn=checkpoint_filter_fn,
  351. feature_cfg=dict(feature_cls='hook', out_indices=out_indices),
  352. **kwargs,
  353. )
  354. return model
  355. def _cfg(url='', **kwargs):
  356. return {
  357. 'url': url,
  358. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  359. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  360. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  361. 'first_conv': 'patch_embed.conv', 'classifier': 'head',
  362. 'license': 'apache-2.0',
  363. **kwargs
  364. }
  365. default_cfgs = generate_default_cfgs({
  366. # deit models (FB weights)
  367. 'pit_ti_224.in1k': _cfg(hf_hub_id='timm/'),
  368. 'pit_xs_224.in1k': _cfg(hf_hub_id='timm/'),
  369. 'pit_s_224.in1k': _cfg(hf_hub_id='timm/'),
  370. 'pit_b_224.in1k': _cfg(hf_hub_id='timm/'),
  371. 'pit_ti_distilled_224.in1k': _cfg(
  372. hf_hub_id='timm/',
  373. classifier=('head', 'head_dist')),
  374. 'pit_xs_distilled_224.in1k': _cfg(
  375. hf_hub_id='timm/',
  376. classifier=('head', 'head_dist')),
  377. 'pit_s_distilled_224.in1k': _cfg(
  378. hf_hub_id='timm/',
  379. classifier=('head', 'head_dist')),
  380. 'pit_b_distilled_224.in1k': _cfg(
  381. hf_hub_id='timm/',
  382. classifier=('head', 'head_dist')),
  383. })
  384. @register_model
  385. def pit_b_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  386. model_args = dict(
  387. patch_size=14,
  388. stride=7,
  389. base_dims=[64, 64, 64],
  390. depth=[3, 6, 4],
  391. heads=[4, 8, 16],
  392. mlp_ratio=4,
  393. )
  394. return _create_pit('pit_b_224', pretrained, **dict(model_args, **kwargs))
  395. @register_model
  396. def pit_s_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  397. model_args = dict(
  398. patch_size=16,
  399. stride=8,
  400. base_dims=[48, 48, 48],
  401. depth=[2, 6, 4],
  402. heads=[3, 6, 12],
  403. mlp_ratio=4,
  404. )
  405. return _create_pit('pit_s_224', pretrained, **dict(model_args, **kwargs))
  406. @register_model
  407. def pit_xs_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  408. model_args = dict(
  409. patch_size=16,
  410. stride=8,
  411. base_dims=[48, 48, 48],
  412. depth=[2, 6, 4],
  413. heads=[2, 4, 8],
  414. mlp_ratio=4,
  415. )
  416. return _create_pit('pit_xs_224', pretrained, **dict(model_args, **kwargs))
  417. @register_model
  418. def pit_ti_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  419. model_args = dict(
  420. patch_size=16,
  421. stride=8,
  422. base_dims=[32, 32, 32],
  423. depth=[2, 6, 4],
  424. heads=[2, 4, 8],
  425. mlp_ratio=4,
  426. )
  427. return _create_pit('pit_ti_224', pretrained, **dict(model_args, **kwargs))
  428. @register_model
  429. def pit_b_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  430. model_args = dict(
  431. patch_size=14,
  432. stride=7,
  433. base_dims=[64, 64, 64],
  434. depth=[3, 6, 4],
  435. heads=[4, 8, 16],
  436. mlp_ratio=4,
  437. distilled=True,
  438. )
  439. return _create_pit('pit_b_distilled_224', pretrained, **dict(model_args, **kwargs))
  440. @register_model
  441. def pit_s_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  442. model_args = dict(
  443. patch_size=16,
  444. stride=8,
  445. base_dims=[48, 48, 48],
  446. depth=[2, 6, 4],
  447. heads=[3, 6, 12],
  448. mlp_ratio=4,
  449. distilled=True,
  450. )
  451. return _create_pit('pit_s_distilled_224', pretrained, **dict(model_args, **kwargs))
  452. @register_model
  453. def pit_xs_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  454. model_args = dict(
  455. patch_size=16,
  456. stride=8,
  457. base_dims=[48, 48, 48],
  458. depth=[2, 6, 4],
  459. heads=[2, 4, 8],
  460. mlp_ratio=4,
  461. distilled=True,
  462. )
  463. return _create_pit('pit_xs_distilled_224', pretrained, **dict(model_args, **kwargs))
  464. @register_model
  465. def pit_ti_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
  466. model_args = dict(
  467. patch_size=16,
  468. stride=8,
  469. base_dims=[32, 32, 32],
  470. depth=[2, 6, 4],
  471. heads=[2, 4, 8],
  472. mlp_ratio=4,
  473. distilled=True,
  474. )
  475. return _create_pit('pit_ti_distilled_224', pretrained, **dict(model_args, **kwargs))