levit.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091
  1. """ LeViT
  2. Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
  3. - https://arxiv.org/abs/2104.01136
  4. @article{graham2021levit,
  5. title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
  6. author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
  7. journal={arXiv preprint arXiv:2104.01136},
  8. year={2021}
  9. }
  10. Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright below.
  11. This version combines both conv/linear models and fixes torchscript compatibility.
  12. Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
  13. """
  14. # Copyright (c) 2015-present, Facebook, Inc.
  15. # All rights reserved.
  16. # Modified from
  17. # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
  18. # Copyright 2020 Ross Wightman, Apache-2.0 License
  19. from collections import OrderedDict
  20. from functools import partial
  21. from typing import Dict, List, Optional, Tuple, Type, Union
  22. import torch
  23. import torch.nn as nn
  24. from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
  25. from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
  26. from ._builder import build_model_with_cfg
  27. from ._features import feature_take_indices
  28. from ._manipulate import checkpoint, checkpoint_seq
  29. from ._registry import generate_default_cfgs, register_model
  30. __all__ = ['Levit']
  31. class ConvNorm(nn.Module):
  32. def __init__(
  33. self,
  34. in_chs: int,
  35. out_chs: int,
  36. kernel_size: int = 1,
  37. stride: int = 1,
  38. padding: int = 0,
  39. dilation: int = 1,
  40. groups: int = 1,
  41. bn_weight_init: float = 1,
  42. device=None,
  43. dtype=None,
  44. ):
  45. dd = {'device': device, 'dtype': dtype}
  46. super().__init__()
  47. self.linear = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias=False, **dd)
  48. self.bn = nn.BatchNorm2d(out_chs, **dd)
  49. nn.init.constant_(self.bn.weight, bn_weight_init)
  50. @torch.no_grad()
  51. def fuse(self):
  52. c, bn = self.linear, self.bn
  53. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  54. w = c.weight * w[:, None, None, None]
  55. b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
  56. m = nn.Conv2d(
  57. w.size(1), w.size(0), w.shape[2:], stride=self.linear.stride,
  58. padding=self.linear.padding, dilation=self.linear.dilation, groups=self.linear.groups)
  59. m.weight.data.copy_(w)
  60. m.bias.data.copy_(b)
  61. return m
  62. def forward(self, x):
  63. return self.bn(self.linear(x))
  64. class LinearNorm(nn.Module):
  65. def __init__(
  66. self,
  67. in_features: int,
  68. out_features: int,
  69. bn_weight_init: float = 1,
  70. device=None,
  71. dtype=None,
  72. ):
  73. dd = {'device': device, 'dtype': dtype}
  74. super().__init__()
  75. self.linear = nn.Linear(in_features, out_features, bias=False, **dd)
  76. self.bn = nn.BatchNorm1d(out_features, **dd)
  77. nn.init.constant_(self.bn.weight, bn_weight_init)
  78. @torch.no_grad()
  79. def fuse(self):
  80. l, bn = self.linear, self.bn
  81. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  82. w = l.weight * w[:, None]
  83. b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
  84. m = nn.Linear(w.size(1), w.size(0))
  85. m.weight.data.copy_(w)
  86. m.bias.data.copy_(b)
  87. return m
  88. def forward(self, x):
  89. x = self.linear(x)
  90. return self.bn(x.flatten(0, 1)).reshape_as(x)
  91. class NormLinear(nn.Module):
  92. def __init__(
  93. self,
  94. in_features: int,
  95. out_features: int,
  96. bias: bool = True,
  97. std: float = 0.02,
  98. drop: float = 0.,
  99. device=None,
  100. dtype=None,
  101. ):
  102. dd = {'device': device, 'dtype': dtype}
  103. super().__init__()
  104. self.bn = nn.BatchNorm1d(in_features, **dd)
  105. self.drop = nn.Dropout(drop)
  106. self.linear = nn.Linear(in_features, out_features, bias=bias, **dd)
  107. trunc_normal_(self.linear.weight, std=std)
  108. if self.linear.bias is not None:
  109. nn.init.constant_(self.linear.bias, 0)
  110. @torch.no_grad()
  111. def fuse(self):
  112. bn, l = self.bn, self.linear
  113. w = bn.weight / (bn.running_var + bn.eps) ** 0.5
  114. b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
  115. w = l.weight * w[None, :]
  116. if l.bias is None:
  117. b = b @ self.linear.weight.T
  118. else:
  119. b = (l.weight @ b[:, None]).view(-1) + self.linear.bias
  120. m = nn.Linear(w.size(1), w.size(0))
  121. m.weight.data.copy_(w)
  122. m.bias.data.copy_(b)
  123. return m
  124. def forward(self, x):
  125. return self.linear(self.drop(self.bn(x)))
  126. class Stem8(nn.Sequential):
  127. def __init__(
  128. self,
  129. in_chs: int,
  130. out_chs: int,
  131. act_layer: Type[nn.Module],
  132. device=None,
  133. dtype=None,
  134. ):
  135. dd = {'device': device, 'dtype': dtype}
  136. super().__init__()
  137. self.stride = 8
  138. self.add_module('conv1', ConvNorm(in_chs, out_chs // 4, 3, stride=2, padding=1, **dd))
  139. self.add_module('act1', act_layer())
  140. self.add_module('conv2', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1, **dd))
  141. self.add_module('act2', act_layer())
  142. self.add_module('conv3', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1, **dd))
  143. class Stem16(nn.Sequential):
  144. def __init__(
  145. self,
  146. in_chs: int,
  147. out_chs: int,
  148. act_layer: Type[nn.Module],
  149. device=None,
  150. dtype=None,
  151. ):
  152. dd = {'device': device, 'dtype': dtype}
  153. super().__init__()
  154. self.stride = 16
  155. self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1, **dd))
  156. self.add_module('act1', act_layer())
  157. self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1, **dd))
  158. self.add_module('act2', act_layer())
  159. self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1, **dd))
  160. self.add_module('act3', act_layer())
  161. self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1, **dd))
  162. class Downsample(nn.Module):
  163. def __init__(
  164. self,
  165. stride: int,
  166. resolution: Union[int, Tuple[int, int]],
  167. use_pool: bool = False,
  168. device=None,
  169. dtype=None,
  170. ):
  171. super().__init__()
  172. self.stride = stride
  173. self.resolution = to_2tuple(resolution)
  174. self.pool = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if use_pool else None
  175. def forward(self, x):
  176. B, N, C = x.shape
  177. x = x.view(B, self.resolution[0], self.resolution[1], C)
  178. if self.pool is not None:
  179. x = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
  180. else:
  181. x = x[:, ::self.stride, ::self.stride]
  182. return x.reshape(B, -1, C)
  183. class Attention(nn.Module):
  184. attention_bias_cache: Dict[str, torch.Tensor]
  185. def __init__(
  186. self,
  187. dim: int,
  188. key_dim: int,
  189. num_heads: int = 8,
  190. attn_ratio: float = 4.,
  191. resolution: Union[int, Tuple[int, int]] = 14,
  192. use_conv: bool = False,
  193. act_layer: Type[nn.Module] = nn.SiLU,
  194. device=None,
  195. dtype=None,
  196. ):
  197. dd = {'device': device, 'dtype': dtype}
  198. super().__init__()
  199. ln_layer = ConvNorm if use_conv else LinearNorm
  200. resolution = to_2tuple(resolution)
  201. self.use_conv = use_conv
  202. self.num_heads = num_heads
  203. self.scale = key_dim ** -0.5
  204. self.key_dim = key_dim
  205. self.key_attn_dim = key_dim * num_heads
  206. self.val_dim = int(attn_ratio * key_dim)
  207. self.val_attn_dim = int(attn_ratio * key_dim) * num_heads
  208. self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2, **dd)
  209. self.proj = nn.Sequential(OrderedDict([
  210. ('act', act_layer()),
  211. ('ln', ln_layer(self.val_attn_dim, dim, bn_weight_init=0, **dd))
  212. ]))
  213. self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1], **dd))
  214. pos = torch.stack(ndgrid(
  215. torch.arange(resolution[0], device=device, dtype=torch.long),
  216. torch.arange(resolution[1], device=device, dtype=torch.long),
  217. )).flatten(1)
  218. rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
  219. rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
  220. self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
  221. self.attention_bias_cache = {}
  222. @torch.no_grad()
  223. def train(self, mode=True):
  224. super().train(mode)
  225. if mode and self.attention_bias_cache:
  226. self.attention_bias_cache = {} # clear ab cache
  227. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  228. if torch.jit.is_tracing() or self.training:
  229. return self.attention_biases[:, self.attention_bias_idxs]
  230. else:
  231. device_key = str(device)
  232. if device_key not in self.attention_bias_cache:
  233. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  234. return self.attention_bias_cache[device_key]
  235. def forward(self, x): # x (B,C,H,W)
  236. if self.use_conv:
  237. B, C, H, W = x.shape
  238. q, k, v = self.qkv(x).view(
  239. B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2)
  240. attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
  241. attn = attn.softmax(dim=-1)
  242. x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
  243. else:
  244. B, N, C = x.shape
  245. q, k, v = self.qkv(x).view(
  246. B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
  247. q = q.permute(0, 2, 1, 3)
  248. k = k.permute(0, 2, 3, 1)
  249. v = v.permute(0, 2, 1, 3)
  250. attn = q @ k * self.scale + self.get_attention_biases(x.device)
  251. attn = attn.softmax(dim=-1)
  252. x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
  253. x = self.proj(x)
  254. return x
  255. class AttentionDownsample(nn.Module):
  256. attention_bias_cache: Dict[str, torch.Tensor]
  257. def __init__(
  258. self,
  259. in_dim: int,
  260. out_dim: int,
  261. key_dim: int,
  262. num_heads: int = 8,
  263. attn_ratio: float = 2.0,
  264. stride: int = 2,
  265. resolution: Union[int, Tuple[int, int]] = 14,
  266. use_conv: bool = False,
  267. use_pool: bool = False,
  268. act_layer: Type[nn.Module] = nn.SiLU,
  269. device=None,
  270. dtype=None,
  271. ):
  272. dd = {'device': device, 'dtype': dtype}
  273. super().__init__()
  274. resolution = to_2tuple(resolution)
  275. self.stride = stride
  276. self.resolution = resolution
  277. self.num_heads = num_heads
  278. self.key_dim = key_dim
  279. self.key_attn_dim = key_dim * num_heads
  280. self.val_dim = int(attn_ratio * key_dim)
  281. self.val_attn_dim = self.val_dim * self.num_heads
  282. self.scale = key_dim ** -0.5
  283. self.use_conv = use_conv
  284. if self.use_conv:
  285. ln_layer = ConvNorm
  286. sub_layer = partial(
  287. nn.AvgPool2d,
  288. kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False)
  289. else:
  290. ln_layer = LinearNorm
  291. sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool, **dd)
  292. self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim, **dd)
  293. self.q = nn.Sequential(OrderedDict([
  294. ('down', sub_layer(stride=stride)),
  295. ('ln', ln_layer(in_dim, self.key_attn_dim, **dd))
  296. ]))
  297. self.proj = nn.Sequential(OrderedDict([
  298. ('act', act_layer()),
  299. ('ln', ln_layer(self.val_attn_dim, out_dim, **dd))
  300. ]))
  301. self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1], **dd))
  302. k_pos = torch.stack(ndgrid(
  303. torch.arange(resolution[0], device=device, dtype=torch.long),
  304. torch.arange(resolution[1], device=device, dtype=torch.long),
  305. )).flatten(1)
  306. q_pos = torch.stack(ndgrid(
  307. torch.arange(0, resolution[0], step=stride, device=device, dtype=torch.long),
  308. torch.arange(0, resolution[1], step=stride, device=device, dtype=torch.long),
  309. )).flatten(1)
  310. rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
  311. rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
  312. self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
  313. self.attention_bias_cache = {} # per-device attention_biases cache
  314. @torch.no_grad()
  315. def train(self, mode=True):
  316. super().train(mode)
  317. if mode and self.attention_bias_cache:
  318. self.attention_bias_cache = {} # clear ab cache
  319. def get_attention_biases(self, device: torch.device) -> torch.Tensor:
  320. if torch.jit.is_tracing() or self.training:
  321. return self.attention_biases[:, self.attention_bias_idxs]
  322. else:
  323. device_key = str(device)
  324. if device_key not in self.attention_bias_cache:
  325. self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
  326. return self.attention_bias_cache[device_key]
  327. def forward(self, x):
  328. if self.use_conv:
  329. B, C, H, W = x.shape
  330. HH, WW = (H - 1) // self.stride + 1, (W - 1) // self.stride + 1
  331. k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
  332. q = self.q(x).view(B, self.num_heads, self.key_dim, -1)
  333. attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
  334. attn = attn.softmax(dim=-1)
  335. x = (v @ attn.transpose(-2, -1)).reshape(B, self.val_attn_dim, HH, WW)
  336. else:
  337. B, N, C = x.shape
  338. k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
  339. k = k.permute(0, 2, 3, 1) # BHCN
  340. v = v.permute(0, 2, 1, 3) # BHNC
  341. q = self.q(x).view(B, -1, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
  342. attn = q @ k * self.scale + self.get_attention_biases(x.device)
  343. attn = attn.softmax(dim=-1)
  344. x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
  345. x = self.proj(x)
  346. return x
  347. class LevitMlp(nn.Module):
  348. """ MLP for Levit w/ normalization + ability to switch btw conv and linear
  349. """
  350. def __init__(
  351. self,
  352. in_features: int,
  353. hidden_features: Optional[int] = None,
  354. out_features: Optional[int] = None,
  355. use_conv: bool = False,
  356. act_layer: Type[nn.Module] = nn.SiLU,
  357. drop: float = 0.,
  358. device=None,
  359. dtype=None,
  360. ):
  361. dd = {'device': device, 'dtype': dtype}
  362. super().__init__()
  363. out_features = out_features or in_features
  364. hidden_features = hidden_features or in_features
  365. ln_layer = ConvNorm if use_conv else LinearNorm
  366. self.ln1 = ln_layer(in_features, hidden_features, **dd)
  367. self.act = act_layer()
  368. self.drop = nn.Dropout(drop)
  369. self.ln2 = ln_layer(hidden_features, out_features, bn_weight_init=0, **dd)
  370. def forward(self, x):
  371. x = self.ln1(x)
  372. x = self.act(x)
  373. x = self.drop(x)
  374. x = self.ln2(x)
  375. return x
  376. class LevitDownsample(nn.Module):
  377. def __init__(
  378. self,
  379. in_dim: int,
  380. out_dim: int,
  381. key_dim: int,
  382. num_heads: int = 8,
  383. attn_ratio: float = 4.,
  384. mlp_ratio: float = 2.,
  385. act_layer: Type[nn.Module] = nn.SiLU,
  386. attn_act_layer: Optional[Type[nn.Module]] = None,
  387. resolution: Union[int, Tuple[int, int]] = 14,
  388. use_conv: bool = False,
  389. use_pool: bool = False,
  390. drop_path: float = 0.,
  391. device=None,
  392. dtype=None,
  393. ):
  394. dd = {'device': device, 'dtype': dtype}
  395. super().__init__()
  396. attn_act_layer = attn_act_layer or act_layer
  397. self.attn_downsample = AttentionDownsample(
  398. in_dim=in_dim,
  399. out_dim=out_dim,
  400. key_dim=key_dim,
  401. num_heads=num_heads,
  402. attn_ratio=attn_ratio,
  403. act_layer=attn_act_layer,
  404. resolution=resolution,
  405. use_conv=use_conv,
  406. use_pool=use_pool,
  407. **dd,
  408. )
  409. self.mlp = LevitMlp(
  410. out_dim,
  411. int(out_dim * mlp_ratio),
  412. use_conv=use_conv,
  413. act_layer=act_layer,
  414. **dd,
  415. )
  416. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  417. def forward(self, x):
  418. x = self.attn_downsample(x)
  419. x = x + self.drop_path(self.mlp(x))
  420. return x
  421. class LevitBlock(nn.Module):
  422. def __init__(
  423. self,
  424. dim: int,
  425. key_dim: int,
  426. num_heads: int = 8,
  427. attn_ratio: float = 4.,
  428. mlp_ratio: float = 2.,
  429. resolution: Union[int, Tuple[int, int]] = 14,
  430. use_conv: bool = False,
  431. act_layer: Type[nn.Module] = nn.SiLU,
  432. attn_act_layer: Optional[Type[nn.Module]] = None,
  433. drop_path: float = 0.,
  434. device=None,
  435. dtype=None,
  436. ):
  437. dd = {'device': device, 'dtype': dtype}
  438. super().__init__()
  439. attn_act_layer = attn_act_layer or act_layer
  440. self.attn = Attention(
  441. dim=dim,
  442. key_dim=key_dim,
  443. num_heads=num_heads,
  444. attn_ratio=attn_ratio,
  445. resolution=resolution,
  446. use_conv=use_conv,
  447. act_layer=attn_act_layer,
  448. **dd,
  449. )
  450. self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  451. self.mlp = LevitMlp(
  452. dim,
  453. int(dim * mlp_ratio),
  454. use_conv=use_conv,
  455. act_layer=act_layer,
  456. **dd,
  457. )
  458. self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  459. def forward(self, x):
  460. x = x + self.drop_path1(self.attn(x))
  461. x = x + self.drop_path2(self.mlp(x))
  462. return x
  463. class LevitStage(nn.Module):
  464. def __init__(
  465. self,
  466. in_dim: int,
  467. out_dim: int,
  468. key_dim: int,
  469. depth: int = 4,
  470. num_heads: int = 8,
  471. attn_ratio: float = 4.0,
  472. mlp_ratio: float = 4.0,
  473. act_layer: Type[nn.Module] = nn.SiLU,
  474. attn_act_layer: Optional[Type[nn.Module]] = None,
  475. resolution: Union[int, Tuple[int, int]] = 14,
  476. downsample: str = '',
  477. use_conv: bool = False,
  478. drop_path: float = 0.,
  479. device=None,
  480. dtype=None,
  481. ):
  482. dd = {'device': device, 'dtype': dtype}
  483. super().__init__()
  484. resolution = to_2tuple(resolution)
  485. if downsample:
  486. self.downsample = LevitDownsample(
  487. in_dim,
  488. out_dim,
  489. key_dim=key_dim,
  490. num_heads=in_dim // key_dim,
  491. attn_ratio=4.,
  492. mlp_ratio=2.,
  493. act_layer=act_layer,
  494. attn_act_layer=attn_act_layer,
  495. resolution=resolution,
  496. use_conv=use_conv,
  497. drop_path=drop_path,
  498. **dd,
  499. )
  500. resolution = [(r - 1) // 2 + 1 for r in resolution]
  501. else:
  502. assert in_dim == out_dim
  503. self.downsample = nn.Identity()
  504. blocks = []
  505. for _ in range(depth):
  506. blocks += [LevitBlock(
  507. out_dim,
  508. key_dim,
  509. num_heads=num_heads,
  510. attn_ratio=attn_ratio,
  511. mlp_ratio=mlp_ratio,
  512. act_layer=act_layer,
  513. attn_act_layer=attn_act_layer,
  514. resolution=resolution,
  515. use_conv=use_conv,
  516. drop_path=drop_path,
  517. **dd,
  518. )]
  519. self.blocks = nn.Sequential(*blocks)
  520. def forward(self, x):
  521. x = self.downsample(x)
  522. x = self.blocks(x)
  523. return x
  524. class Levit(nn.Module):
  525. """ Vision Transformer with support for patch or hybrid CNN input stage
  526. NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems
  527. w/ train scripts that don't take tuple outputs,
  528. """
  529. def __init__(
  530. self,
  531. img_size: Union[int, Tuple[int, int]] = 224,
  532. in_chans: int = 3,
  533. num_classes: int = 1000,
  534. embed_dim: Tuple[int, ...] = (192,),
  535. key_dim: int = 64,
  536. depth: Tuple[int, ...] = (12,),
  537. num_heads: Union[int, Tuple[int, ...]] = (3,),
  538. attn_ratio: Union[float, Tuple[float, ...]] = 2.,
  539. mlp_ratio: Union[float, Tuple[float, ...]] = 2.,
  540. stem_backbone: Optional[nn.Module] = None,
  541. stem_stride: Optional[int] = None,
  542. stem_type: str = 's16',
  543. down_op: str = 'subsample',
  544. act_layer: str = 'hard_swish',
  545. attn_act_layer: Optional[str] = None,
  546. use_conv: bool = False,
  547. global_pool: str = 'avg',
  548. drop_rate: float = 0.,
  549. drop_path_rate: float = 0.,
  550. device=None,
  551. dtype=None,
  552. ):
  553. super().__init__()
  554. dd = {'device': device, 'dtype': dtype}
  555. act_layer = get_act_layer(act_layer)
  556. attn_act_layer = get_act_layer(attn_act_layer or act_layer)
  557. self.use_conv = use_conv
  558. self.num_classes = num_classes
  559. self.global_pool = global_pool
  560. self.num_features = self.head_hidden_size = embed_dim[-1]
  561. self.embed_dim = embed_dim
  562. self.drop_rate = drop_rate
  563. self.grad_checkpointing = False
  564. self.feature_info = []
  565. num_stages = len(embed_dim)
  566. assert len(depth) == num_stages
  567. num_heads = to_ntuple(num_stages)(num_heads)
  568. attn_ratio = to_ntuple(num_stages)(attn_ratio)
  569. mlp_ratio = to_ntuple(num_stages)(mlp_ratio)
  570. if stem_backbone is not None:
  571. assert stem_stride >= 2
  572. self.stem = stem_backbone
  573. stride = stem_stride
  574. else:
  575. assert stem_type in ('s16', 's8')
  576. if stem_type == 's16':
  577. self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer, **dd)
  578. else:
  579. self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer, **dd)
  580. stride = self.stem.stride
  581. resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])
  582. in_dim = embed_dim[0]
  583. stages = []
  584. for i in range(num_stages):
  585. stage_stride = 2 if i > 0 else 1
  586. stages += [LevitStage(
  587. in_dim,
  588. embed_dim[i],
  589. key_dim,
  590. depth=depth[i],
  591. num_heads=num_heads[i],
  592. attn_ratio=attn_ratio[i],
  593. mlp_ratio=mlp_ratio[i],
  594. act_layer=act_layer,
  595. attn_act_layer=attn_act_layer,
  596. resolution=resolution,
  597. use_conv=use_conv,
  598. downsample=down_op if stage_stride == 2 else '',
  599. drop_path=drop_path_rate,
  600. **dd,
  601. )]
  602. stride *= stage_stride
  603. resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
  604. self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
  605. in_dim = embed_dim[i]
  606. self.stages = nn.Sequential(*stages)
  607. # Classifier head
  608. self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate, **dd) if num_classes > 0 else nn.Identity()
  609. @torch.jit.ignore
  610. def no_weight_decay(self):
  611. return {x for x in self.state_dict().keys() if 'attention_biases' in x}
  612. @torch.jit.ignore
  613. def group_matcher(self, coarse=False):
  614. matcher = dict(
  615. stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
  616. blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
  617. )
  618. return matcher
  619. @torch.jit.ignore
  620. def set_grad_checkpointing(self, enable=True):
  621. self.grad_checkpointing = enable
  622. @torch.jit.ignore
  623. def get_classifier(self) -> nn.Module:
  624. return self.head
  625. def reset_classifier(self, num_classes: int , global_pool: Optional[str] = None):
  626. self.num_classes = num_classes
  627. if global_pool is not None:
  628. self.global_pool = global_pool
  629. self.head = NormLinear(
  630. self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
  631. def forward_intermediates(
  632. self,
  633. x: torch.Tensor,
  634. indices: Optional[Union[int, List[int]]] = None,
  635. norm: bool = False,
  636. stop_early: bool = False,
  637. output_fmt: str = 'NCHW',
  638. intermediates_only: bool = False,
  639. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  640. """ Forward features that returns intermediates.
  641. Args:
  642. x: Input image tensor
  643. indices: Take last n blocks if int, all if None, select matching indices if sequence
  644. norm: Apply norm layer to compatible intermediates
  645. stop_early: Stop iterating over blocks when last desired intermediate hit
  646. output_fmt: Shape of intermediate feature outputs
  647. intermediates_only: Only return intermediate features
  648. Returns:
  649. """
  650. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  651. intermediates = []
  652. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  653. # forward pass
  654. x = self.stem(x)
  655. B, C, H, W = x.shape
  656. if not self.use_conv:
  657. x = x.flatten(2).transpose(1, 2)
  658. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  659. stages = self.stages
  660. else:
  661. stages = self.stages[:max_index + 1]
  662. for feat_idx, stage in enumerate(stages):
  663. if self.grad_checkpointing and not torch.jit.is_scripting():
  664. x = checkpoint(stage, x)
  665. else:
  666. x = stage(x)
  667. if feat_idx in take_indices:
  668. if self.use_conv:
  669. intermediates.append(x)
  670. else:
  671. intermediates.append(x.reshape(B, H, W, -1).permute(0, 3, 1, 2))
  672. H = (H + 2 - 1) // 2
  673. W = (W + 2 - 1) // 2
  674. if intermediates_only:
  675. return intermediates
  676. return x, intermediates
  677. def prune_intermediate_layers(
  678. self,
  679. indices: Union[int, List[int]] = 1,
  680. prune_norm: bool = False,
  681. prune_head: bool = True,
  682. ):
  683. """ Prune layers not required for specified intermediates.
  684. """
  685. take_indices, max_index = feature_take_indices(len(self.stages), indices)
  686. self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  687. if prune_head:
  688. self.reset_classifier(0, '')
  689. return take_indices
  690. def forward_features(self, x):
  691. x = self.stem(x)
  692. if not self.use_conv:
  693. x = x.flatten(2).transpose(1, 2)
  694. if self.grad_checkpointing and not torch.jit.is_scripting():
  695. x = checkpoint_seq(self.stages, x)
  696. else:
  697. x = self.stages(x)
  698. return x
  699. def forward_head(self, x, pre_logits: bool = False):
  700. if self.global_pool == 'avg':
  701. x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
  702. return x if pre_logits else self.head(x)
  703. def forward(self, x):
  704. x = self.forward_features(x)
  705. x = self.forward_head(x)
  706. return x
  707. class LevitDistilled(Levit):
  708. def __init__(self, *args, **kwargs):
  709. super().__init__(*args, **kwargs)
  710. dd = {'device': kwargs.get('device', None), 'dtype': kwargs.get('dtype', None)}
  711. self.head_dist = NormLinear(self.num_features, self.num_classes, **dd) if self.num_classes > 0 else nn.Identity()
  712. self.distilled_training = False # must set this True to train w/ distillation token
  713. @torch.jit.ignore
  714. def get_classifier(self) -> nn.Module:
  715. return self.head, self.head_dist
  716. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  717. self.num_classes = num_classes
  718. if global_pool is not None:
  719. self.global_pool = global_pool
  720. self.head = NormLinear(
  721. self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
  722. self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  723. @torch.jit.ignore
  724. def set_distilled_training(self, enable=True):
  725. self.distilled_training = enable
  726. def forward_head(self, x, pre_logits: bool = False):
  727. if self.global_pool == 'avg':
  728. x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
  729. if pre_logits:
  730. return x
  731. x, x_dist = self.head(x), self.head_dist(x)
  732. if self.distilled_training and self.training and not torch.jit.is_scripting():
  733. # only return separate classification predictions when training in distilled mode
  734. return x, x_dist
  735. else:
  736. # during standard train/finetune, inference average the classifier predictions
  737. return (x + x_dist) / 2
  738. def checkpoint_filter_fn(state_dict, model):
  739. if 'model' in state_dict:
  740. state_dict = state_dict['model']
  741. # filter out attn biases, should not have been persistent
  742. state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
  743. # NOTE: old weight conversion code, disabled
  744. # D = model.state_dict()
  745. # out_dict = {}
  746. # for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
  747. # if va.ndim == 4 and vb.ndim == 2:
  748. # vb = vb[:, :, None, None]
  749. # if va.shape != vb.shape:
  750. # # head or first-conv shapes may change for fine-tune
  751. # assert 'head' in ka or 'stem.conv1.linear' in ka
  752. # out_dict[ka] = vb
  753. return state_dict
  754. model_cfgs = dict(
  755. levit_128s=dict(
  756. embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
  757. levit_128=dict(
  758. embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
  759. levit_192=dict(
  760. embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
  761. levit_256=dict(
  762. embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
  763. levit_384=dict(
  764. embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
  765. # stride-8 stem experiments
  766. levit_384_s8=dict(
  767. embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4),
  768. act_layer='silu', stem_type='s8'),
  769. levit_512_s8=dict(
  770. embed_dim=(512, 640, 896), key_dim=64, num_heads=(8, 10, 14), depth=(4, 4, 4),
  771. act_layer='silu', stem_type='s8'),
  772. # wider experiments
  773. levit_512=dict(
  774. embed_dim=(512, 768, 1024), key_dim=64, num_heads=(8, 12, 16), depth=(4, 4, 4), act_layer='silu'),
  775. # deeper experiments
  776. levit_256d=dict(
  777. embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6), act_layer='silu'),
  778. levit_512d=dict(
  779. embed_dim=(512, 640, 768), key_dim=64, num_heads=(8, 10, 12), depth=(4, 8, 6), act_layer='silu'),
  780. )
  781. def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs):
  782. is_conv = '_conv' in variant
  783. out_indices = kwargs.pop('out_indices', (0, 1, 2))
  784. if kwargs.get('features_only', False) and not is_conv:
  785. kwargs.setdefault('feature_cls', 'getter')
  786. if cfg_variant is None:
  787. if variant in model_cfgs:
  788. cfg_variant = variant
  789. elif is_conv:
  790. cfg_variant = variant.replace('_conv', '')
  791. model_cfg = dict(model_cfgs[cfg_variant], **kwargs)
  792. model = build_model_with_cfg(
  793. LevitDistilled if distilled else Levit,
  794. variant,
  795. pretrained,
  796. pretrained_filter_fn=checkpoint_filter_fn,
  797. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  798. **model_cfg,
  799. )
  800. return model
  801. def _cfg(url='', **kwargs):
  802. return {
  803. 'url': url,
  804. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  805. 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
  806. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  807. 'first_conv': 'stem.conv1.linear', 'classifier': ('head.linear', 'head_dist.linear'),
  808. 'license': 'apache-2.0',
  809. **kwargs
  810. }
  811. default_cfgs = generate_default_cfgs({
  812. # weights in nn.Linear mode
  813. 'levit_128s.fb_dist_in1k': _cfg(
  814. hf_hub_id='timm/',
  815. ),
  816. 'levit_128.fb_dist_in1k': _cfg(
  817. hf_hub_id='timm/',
  818. ),
  819. 'levit_192.fb_dist_in1k': _cfg(
  820. hf_hub_id='timm/',
  821. ),
  822. 'levit_256.fb_dist_in1k': _cfg(
  823. hf_hub_id='timm/',
  824. ),
  825. 'levit_384.fb_dist_in1k': _cfg(
  826. hf_hub_id='timm/',
  827. ),
  828. # weights in nn.Conv2d mode
  829. 'levit_conv_128s.fb_dist_in1k': _cfg(
  830. hf_hub_id='timm/',
  831. pool_size=(4, 4),
  832. ),
  833. 'levit_conv_128.fb_dist_in1k': _cfg(
  834. hf_hub_id='timm/',
  835. pool_size=(4, 4),
  836. ),
  837. 'levit_conv_192.fb_dist_in1k': _cfg(
  838. hf_hub_id='timm/',
  839. pool_size=(4, 4),
  840. ),
  841. 'levit_conv_256.fb_dist_in1k': _cfg(
  842. hf_hub_id='timm/',
  843. pool_size=(4, 4),
  844. ),
  845. 'levit_conv_384.fb_dist_in1k': _cfg(
  846. hf_hub_id='timm/',
  847. pool_size=(4, 4),
  848. ),
  849. 'levit_384_s8.untrained': _cfg(classifier='head.linear'),
  850. 'levit_512_s8.untrained': _cfg(classifier='head.linear'),
  851. 'levit_512.untrained': _cfg(classifier='head.linear'),
  852. 'levit_256d.untrained': _cfg(classifier='head.linear'),
  853. 'levit_512d.untrained': _cfg(classifier='head.linear'),
  854. 'levit_conv_384_s8.untrained': _cfg(classifier='head.linear'),
  855. 'levit_conv_512_s8.untrained': _cfg(classifier='head.linear'),
  856. 'levit_conv_512.untrained': _cfg(classifier='head.linear'),
  857. 'levit_conv_256d.untrained': _cfg(classifier='head.linear'),
  858. 'levit_conv_512d.untrained': _cfg(classifier='head.linear'),
  859. })
  860. @register_model
  861. def levit_128s(pretrained=False, **kwargs) -> Levit:
  862. return create_levit('levit_128s', pretrained=pretrained, **kwargs)
  863. @register_model
  864. def levit_128(pretrained=False, **kwargs) -> Levit:
  865. return create_levit('levit_128', pretrained=pretrained, **kwargs)
  866. @register_model
  867. def levit_192(pretrained=False, **kwargs) -> Levit:
  868. return create_levit('levit_192', pretrained=pretrained, **kwargs)
  869. @register_model
  870. def levit_256(pretrained=False, **kwargs) -> Levit:
  871. return create_levit('levit_256', pretrained=pretrained, **kwargs)
  872. @register_model
  873. def levit_384(pretrained=False, **kwargs) -> Levit:
  874. return create_levit('levit_384', pretrained=pretrained, **kwargs)
  875. @register_model
  876. def levit_384_s8(pretrained=False, **kwargs) -> Levit:
  877. return create_levit('levit_384_s8', pretrained=pretrained, **kwargs)
  878. @register_model
  879. def levit_512_s8(pretrained=False, **kwargs) -> Levit:
  880. return create_levit('levit_512_s8', pretrained=pretrained, distilled=False, **kwargs)
  881. @register_model
  882. def levit_512(pretrained=False, **kwargs) -> Levit:
  883. return create_levit('levit_512', pretrained=pretrained, distilled=False, **kwargs)
  884. @register_model
  885. def levit_256d(pretrained=False, **kwargs) -> Levit:
  886. return create_levit('levit_256d', pretrained=pretrained, distilled=False, **kwargs)
  887. @register_model
  888. def levit_512d(pretrained=False, **kwargs) -> Levit:
  889. return create_levit('levit_512d', pretrained=pretrained, distilled=False, **kwargs)
  890. @register_model
  891. def levit_conv_128s(pretrained=False, **kwargs) -> Levit:
  892. return create_levit('levit_conv_128s', pretrained=pretrained, use_conv=True, **kwargs)
  893. @register_model
  894. def levit_conv_128(pretrained=False, **kwargs) -> Levit:
  895. return create_levit('levit_conv_128', pretrained=pretrained, use_conv=True, **kwargs)
  896. @register_model
  897. def levit_conv_192(pretrained=False, **kwargs) -> Levit:
  898. return create_levit('levit_conv_192', pretrained=pretrained, use_conv=True, **kwargs)
  899. @register_model
  900. def levit_conv_256(pretrained=False, **kwargs) -> Levit:
  901. return create_levit('levit_conv_256', pretrained=pretrained, use_conv=True, **kwargs)
  902. @register_model
  903. def levit_conv_384(pretrained=False, **kwargs) -> Levit:
  904. return create_levit('levit_conv_384', pretrained=pretrained, use_conv=True, **kwargs)
  905. @register_model
  906. def levit_conv_384_s8(pretrained=False, **kwargs) -> Levit:
  907. return create_levit('levit_conv_384_s8', pretrained=pretrained, use_conv=True, **kwargs)
  908. @register_model
  909. def levit_conv_512_s8(pretrained=False, **kwargs) -> Levit:
  910. return create_levit('levit_conv_512_s8', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
  911. @register_model
  912. def levit_conv_512(pretrained=False, **kwargs) -> Levit:
  913. return create_levit('levit_conv_512', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
  914. @register_model
  915. def levit_conv_256d(pretrained=False, **kwargs) -> Levit:
  916. return create_levit('levit_conv_256d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
  917. @register_model
  918. def levit_conv_512d(pretrained=False, **kwargs) -> Levit:
  919. return create_levit('levit_conv_512d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)