deit.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. """ DeiT - Data-efficient Image Transformers
  2. DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
  3. paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
  4. paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
  5. Modifications copyright 2021, Ross Wightman
  6. """
  7. # Copyright (c) 2015-present, Facebook, Inc.
  8. # All rights reserved.
  9. from functools import partial
  10. from typing import Optional, Type
  11. import torch
  12. from torch import nn as nn
  13. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  14. from timm.layers import resample_abs_pos_embed
  15. from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
  16. from ._builder import build_model_with_cfg
  17. from ._registry import generate_default_cfgs, register_model, register_model_deprecations
  18. __all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
  19. class VisionTransformerDistilled(VisionTransformer):
  20. """ Vision Transformer w/ Distillation Token and Head
  21. Distillation token & head support for `DeiT: Data-efficient Image Transformers`
  22. - https://arxiv.org/abs/2012.12877
  23. """
  24. def __init__(self, *args, **kwargs):
  25. weight_init = kwargs.pop('weight_init', '')
  26. super().__init__(*args, **kwargs, weight_init='skip')
  27. assert self.global_pool in ('token',)
  28. dd = {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)}
  29. self.num_prefix_tokens = 2
  30. self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim, **dd))
  31. self.pos_embed = nn.Parameter(
  32. torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim, **dd))
  33. self.head_dist = nn.Linear(self.embed_dim, self.num_classes, **dd) if self.num_classes > 0 else nn.Identity()
  34. self.distilled_training = False # must set this True to train w/ distillation token
  35. self.init_weights(weight_init)
  36. def init_weights(self, mode=''):
  37. trunc_normal_(self.dist_token, std=.02)
  38. super().init_weights(mode=mode)
  39. @torch.jit.ignore
  40. def group_matcher(self, coarse=False):
  41. return dict(
  42. stem=r'^cls_token|pos_embed|patch_embed|dist_token',
  43. blocks=[
  44. (r'^blocks\.(\d+)', None),
  45. (r'^norm', (99999,))] # final norm w/ last block
  46. )
  47. @torch.jit.ignore
  48. def get_classifier(self) -> nn.Module:
  49. return self.head, self.head_dist
  50. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  51. self.num_classes = num_classes
  52. self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
  53. self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
  54. @torch.jit.ignore
  55. def set_distilled_training(self, enable=True):
  56. self.distilled_training = enable
  57. def _pos_embed(self, x):
  58. if self.dynamic_img_size:
  59. B, H, W, C = x.shape
  60. prev_grid_size = self.patch_embed.grid_size
  61. pos_embed = resample_abs_pos_embed(
  62. self.pos_embed,
  63. new_size=(H, W),
  64. old_size=prev_grid_size,
  65. num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
  66. )
  67. x = x.view(B, -1, C)
  68. else:
  69. pos_embed = self.pos_embed
  70. if self.no_embed_class:
  71. # deit-3, updated JAX (big vision)
  72. # position embedding does not overlap with class token, add then concat
  73. x = x + pos_embed
  74. x = torch.cat((
  75. self.cls_token.expand(x.shape[0], -1, -1),
  76. self.dist_token.expand(x.shape[0], -1, -1),
  77. x),
  78. dim=1)
  79. else:
  80. # original timm, JAX, and deit vit impl
  81. # pos_embed has entry for class token, concat then add
  82. x = torch.cat((
  83. self.cls_token.expand(x.shape[0], -1, -1),
  84. self.dist_token.expand(x.shape[0], -1, -1),
  85. x),
  86. dim=1)
  87. x = x + pos_embed
  88. return self.pos_drop(x)
  89. def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
  90. x, x_dist = x[:, 0], x[:, 1]
  91. if pre_logits:
  92. return (x + x_dist) / 2
  93. x = self.head(x)
  94. x_dist = self.head_dist(x_dist)
  95. if self.distilled_training and self.training and not torch.jit.is_scripting():
  96. # only return separate classification predictions when training in distilled mode
  97. return x, x_dist
  98. else:
  99. # during standard train / finetune, inference average the classifier predictions
  100. return (x + x_dist) / 2
  101. def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
  102. out_indices = kwargs.pop('out_indices', 3)
  103. model_cls = VisionTransformerDistilled if distilled else VisionTransformer
  104. model = build_model_with_cfg(
  105. model_cls,
  106. variant,
  107. pretrained,
  108. pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
  109. feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
  110. **kwargs,
  111. )
  112. return model
  113. def _cfg(url='', **kwargs):
  114. return {
  115. 'url': url,
  116. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  117. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  118. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  119. 'first_conv': 'patch_embed.proj', 'classifier': 'head',
  120. 'license': 'apache-2.0',
  121. **kwargs
  122. }
# Pretrained configs keyed '<model_name>.<tag>'. Every entry starts from ``_cfg``
# defaults; 384px variants override input_size/crop_pct, distilled variants list
# both classifier heads.
default_cfgs = generate_default_cfgs({
    # deit models (FB weights)
    'deit_tiny_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
    'deit_small_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
    'deit_base_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
    'deit_base_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    # distilled deit models: two classifiers (class-token head + distillation-token head)
    'deit_tiny_distilled_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
        classifier=('head', 'head_dist')),
    'deit_small_distilled_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
        classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
        classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
        classifier=('head', 'head_dist')),
    # deit-3 models trained on ImageNet-1k
    'deit3_small_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
    'deit3_small_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_medium_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth'),
    'deit3_base_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
    'deit3_base_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_large_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
    'deit3_large_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_huge_patch14_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
    # deit-3 models pretrained on ImageNet-22k, fine-tuned on ImageNet-1k
    # (all evaluated with crop_pct=1.0)
    'deit3_small_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
        crop_pct=1.0),
    'deit3_small_patch16_384.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_medium_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth',
        crop_pct=1.0),
    'deit3_base_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
        crop_pct=1.0),
    'deit3_base_patch16_384.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_large_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
        crop_pct=1.0),
    'deit3_large_patch16_384.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_huge_patch14_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
        crop_pct=1.0),
})
  215. @register_model
  216. def deit_tiny_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  217. """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  218. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  219. """
  220. model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
  221. model = _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  222. return model
  223. @register_model
  224. def deit_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  225. """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  226. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  227. """
  228. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
  229. model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  230. return model
  231. @register_model
  232. def deit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  233. """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  234. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  235. """
  236. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  237. model = _create_deit('deit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  238. return model
  239. @register_model
  240. def deit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
  241. """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
  242. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  243. """
  244. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  245. model = _create_deit('deit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  246. return model
  247. @register_model
  248. def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
  249. """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  250. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  251. """
  252. model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
  253. model = _create_deit(
  254. 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
  255. return model
  256. @register_model
  257. def deit_small_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
  258. """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  259. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  260. """
  261. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
  262. model = _create_deit(
  263. 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
  264. return model
  265. @register_model
  266. def deit_base_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
  267. """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
  268. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  269. """
  270. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  271. model = _create_deit(
  272. 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
  273. return model
  274. @register_model
  275. def deit_base_distilled_patch16_384(pretrained=False, **kwargs) -> VisionTransformerDistilled:
  276. """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
  277. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  278. """
  279. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
  280. model = _create_deit(
  281. 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
  282. return model
  283. @register_model
  284. def deit3_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  285. """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
  286. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  287. """
  288. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
  289. model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  290. return model
  291. @register_model
  292. def deit3_small_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
  293. """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
  294. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  295. """
  296. model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
  297. model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  298. return model
  299. @register_model
  300. def deit3_medium_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  301. """ DeiT-3 medium model @ 224x224 (https://arxiv.org/abs/2012.12877).
  302. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  303. """
  304. model_args = dict(patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6)
  305. model = _create_deit('deit3_medium_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  306. return model
  307. @register_model
  308. def deit3_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  309. """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
  310. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  311. """
  312. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
  313. model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  314. return model
  315. @register_model
  316. def deit3_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
  317. """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
  318. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  319. """
  320. model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
  321. model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  322. return model
  323. @register_model
  324. def deit3_large_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
  325. """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
  326. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  327. """
  328. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
  329. model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
  330. return model
  331. @register_model
  332. def deit3_large_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
  333. """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
  334. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  335. """
  336. model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
  337. model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
  338. return model
  339. @register_model
  340. def deit3_huge_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
  341. """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
  342. ImageNet-1k weights from https://github.com/facebookresearch/deit.
  343. """
  344. model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6)
  345. model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
  346. return model
  347. register_model_deprecations(__name__, {
  348. 'deit3_small_patch16_224_in21ft1k': 'deit3_small_patch16_224.fb_in22k_ft_in1k',
  349. 'deit3_small_patch16_384_in21ft1k': 'deit3_small_patch16_384.fb_in22k_ft_in1k',
  350. 'deit3_medium_patch16_224_in21ft1k': 'deit3_medium_patch16_224.fb_in22k_ft_in1k',
  351. 'deit3_base_patch16_224_in21ft1k': 'deit3_base_patch16_224.fb_in22k_ft_in1k',
  352. 'deit3_base_patch16_384_in21ft1k': 'deit3_base_patch16_384.fb_in22k_ft_in1k',
  353. 'deit3_large_patch16_224_in21ft1k': 'deit3_large_patch16_224.fb_in22k_ft_in1k',
  354. 'deit3_large_patch16_384_in21ft1k': 'deit3_large_patch16_384.fb_in22k_ft_in1k',
  355. 'deit3_huge_patch14_224_in21ft1k': 'deit3_huge_patch14_224.fb_in22k_ft_in1k'
  356. })