sequencer.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. """ Sequencer
  2. Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2205.01972.pdf
  3. """
  4. # Copyright (c) 2022. Yuki Tatsunami
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. import math
  7. from functools import partial
  8. from itertools import accumulate
  9. from typing import List, Optional, Tuple, Type, Union
  10. import torch
  11. import torch.nn as nn
  12. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
  13. from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed, ClassifierHead
  14. from ._builder import build_model_with_cfg
  15. from ._manipulate import named_apply
  16. from ._registry import register_model, generate_default_cfgs
# Public API of this module. Only the model class is listed explicitly;
# model_registry will add each entrypoint fn to this.
__all__ = ['Sequencer2d']
  18. def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
  19. if isinstance(module, nn.Linear):
  20. if name.startswith('head'):
  21. nn.init.zeros_(module.weight)
  22. nn.init.constant_(module.bias, head_bias)
  23. else:
  24. if flax:
  25. # Flax defaults
  26. lecun_normal_(module.weight)
  27. if module.bias is not None:
  28. nn.init.zeros_(module.bias)
  29. else:
  30. nn.init.xavier_uniform_(module.weight)
  31. if module.bias is not None:
  32. if 'mlp' in name:
  33. nn.init.normal_(module.bias, std=1e-6)
  34. else:
  35. nn.init.zeros_(module.bias)
  36. elif isinstance(module, nn.Conv2d):
  37. lecun_normal_(module.weight)
  38. if module.bias is not None:
  39. nn.init.zeros_(module.bias)
  40. elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
  41. nn.init.ones_(module.weight)
  42. nn.init.zeros_(module.bias)
  43. elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)):
  44. stdv = 1.0 / math.sqrt(module.hidden_size)
  45. for weight in module.parameters():
  46. nn.init.uniform_(weight, -stdv, stdv)
  47. elif hasattr(module, 'init_weights'):
  48. module.init_weights()
  49. class RNNIdentity(nn.Module):
  50. def __init__(self, *args, **kwargs):
  51. super().__init__()
  52. def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
  53. return x, None
  54. class RNN2dBase(nn.Module):
  55. def __init__(
  56. self,
  57. input_size: int,
  58. hidden_size: int,
  59. num_layers: int = 1,
  60. bias: bool = True,
  61. bidirectional: bool = True,
  62. union: str = "cat",
  63. with_fc: bool = True,
  64. device=None,
  65. dtype=None,
  66. ):
  67. dd = {'device': device, 'dtype': dtype}
  68. super().__init__()
  69. self.input_size = input_size
  70. self.hidden_size = hidden_size
  71. self.output_size = 2 * hidden_size if bidirectional else hidden_size
  72. self.union = union
  73. self.with_vertical = True
  74. self.with_horizontal = True
  75. self.with_fc = with_fc
  76. self.fc = None
  77. if with_fc:
  78. if union == "cat":
  79. self.fc = nn.Linear(2 * self.output_size, input_size, **dd)
  80. elif union == "add":
  81. self.fc = nn.Linear(self.output_size, input_size, **dd)
  82. elif union == "vertical":
  83. self.fc = nn.Linear(self.output_size, input_size, **dd)
  84. self.with_horizontal = False
  85. elif union == "horizontal":
  86. self.fc = nn.Linear(self.output_size, input_size, **dd)
  87. self.with_vertical = False
  88. else:
  89. raise ValueError("Unrecognized union: " + union)
  90. elif union == "cat":
  91. pass
  92. if 2 * self.output_size != input_size:
  93. raise ValueError(f"The output channel {2 * self.output_size} is different from the input channel {input_size}.")
  94. elif union == "add":
  95. pass
  96. if self.output_size != input_size:
  97. raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
  98. elif union == "vertical":
  99. if self.output_size != input_size:
  100. raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
  101. self.with_horizontal = False
  102. elif union == "horizontal":
  103. if self.output_size != input_size:
  104. raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
  105. self.with_vertical = False
  106. else:
  107. raise ValueError("Unrecognized union: " + union)
  108. self.rnn_v = RNNIdentity()
  109. self.rnn_h = RNNIdentity()
  110. def forward(self, x):
  111. B, H, W, C = x.shape
  112. if self.with_vertical:
  113. v = x.permute(0, 2, 1, 3)
  114. v = v.reshape(-1, H, C)
  115. v, _ = self.rnn_v(v)
  116. v = v.reshape(B, W, H, -1)
  117. v = v.permute(0, 2, 1, 3)
  118. else:
  119. v = None
  120. if self.with_horizontal:
  121. h = x.reshape(-1, W, C)
  122. h, _ = self.rnn_h(h)
  123. h = h.reshape(B, H, W, -1)
  124. else:
  125. h = None
  126. if v is not None and h is not None:
  127. if self.union == "cat":
  128. x = torch.cat([v, h], dim=-1)
  129. else:
  130. x = v + h
  131. elif v is not None:
  132. x = v
  133. elif h is not None:
  134. x = h
  135. if self.fc is not None:
  136. x = self.fc(x)
  137. return x
  138. class LSTM2d(RNN2dBase):
  139. def __init__(
  140. self,
  141. input_size: int,
  142. hidden_size: int,
  143. num_layers: int = 1,
  144. bias: bool = True,
  145. bidirectional: bool = True,
  146. union: str = "cat",
  147. with_fc: bool = True,
  148. device=None,
  149. dtype=None,
  150. ):
  151. dd = {'device': device, 'dtype': dtype}
  152. super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc, device, dtype)
  153. if self.with_vertical:
  154. self.rnn_v = nn.LSTM(
  155. input_size,
  156. hidden_size,
  157. num_layers,
  158. batch_first=True,
  159. bias=bias,
  160. bidirectional=bidirectional,
  161. **dd,
  162. )
  163. if self.with_horizontal:
  164. self.rnn_h = nn.LSTM(
  165. input_size,
  166. hidden_size,
  167. num_layers,
  168. batch_first=True,
  169. bias=bias,
  170. bidirectional=bidirectional,
  171. **dd,
  172. )
  173. class Sequencer2dBlock(nn.Module):
  174. def __init__(
  175. self,
  176. dim: int,
  177. hidden_size: int,
  178. mlp_ratio: float = 3.0,
  179. rnn_layer: Type[nn.Module] = LSTM2d,
  180. mlp_layer: Type[nn.Module] = Mlp,
  181. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  182. act_layer: Type[nn.Module] = nn.GELU,
  183. num_layers: int = 1,
  184. bidirectional: bool = True,
  185. union: str = "cat",
  186. with_fc: bool = True,
  187. drop: float = 0.,
  188. drop_path: float = 0.,
  189. device=None,
  190. dtype=None,
  191. ):
  192. dd = {'device': device, 'dtype': dtype}
  193. super().__init__()
  194. channels_dim = int(mlp_ratio * dim)
  195. self.norm1 = norm_layer(dim, **dd)
  196. self.rnn_tokens = rnn_layer(
  197. dim,
  198. hidden_size,
  199. num_layers=num_layers,
  200. bidirectional=bidirectional,
  201. union=union,
  202. with_fc=with_fc,
  203. **dd,
  204. )
  205. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  206. self.norm2 = norm_layer(dim, **dd)
  207. self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop, **dd)
  208. def forward(self, x):
  209. x = x + self.drop_path(self.rnn_tokens(self.norm1(x)))
  210. x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
  211. return x
  212. class Shuffle(nn.Module):
  213. def __init__(self):
  214. super().__init__()
  215. def forward(self, x):
  216. if self.training:
  217. B, H, W, C = x.shape
  218. r = torch.randperm(H * W)
  219. x = x.reshape(B, -1, C)
  220. x = x[:, r, :].reshape(B, H, W, -1)
  221. return x
  222. class Downsample2d(nn.Module):
  223. def __init__(
  224. self,
  225. input_dim: int,
  226. output_dim: int,
  227. patch_size: int,
  228. device=None,
  229. dtype=None,
  230. ):
  231. dd = {'device': device, 'dtype': dtype}
  232. super().__init__()
  233. self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size, **dd)
  234. def forward(self, x):
  235. x = x.permute(0, 3, 1, 2)
  236. x = self.down(x)
  237. x = x.permute(0, 2, 3, 1)
  238. return x
  239. class Sequencer2dStage(nn.Module):
  240. def __init__(
  241. self,
  242. dim: int,
  243. dim_out: int,
  244. depth: int,
  245. patch_size: int,
  246. hidden_size: int,
  247. mlp_ratio: float,
  248. downsample: bool = False,
  249. block_layer: Type[nn.Module] = Sequencer2dBlock,
  250. rnn_layer: Type[nn.Module] = LSTM2d,
  251. mlp_layer: Type[nn.Module] = Mlp,
  252. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  253. act_layer: Type[nn.Module] = nn.GELU,
  254. num_layers: int = 1,
  255. bidirectional: bool = True,
  256. union: str = "cat",
  257. with_fc: bool = True,
  258. drop: float = 0.,
  259. drop_path: Union[float, List[float]] = 0.,
  260. device=None,
  261. dtype=None,
  262. ):
  263. super().__init__()
  264. dd = {'device': device, 'dtype': dtype}
  265. if downsample:
  266. self.downsample = Downsample2d(dim, dim_out, patch_size, **dd)
  267. else:
  268. assert dim == dim_out
  269. self.downsample = nn.Identity()
  270. blocks = []
  271. for block_idx in range(depth):
  272. blocks.append(block_layer(
  273. dim_out,
  274. hidden_size,
  275. mlp_ratio=mlp_ratio,
  276. rnn_layer=rnn_layer,
  277. mlp_layer=mlp_layer,
  278. norm_layer=norm_layer,
  279. act_layer=act_layer,
  280. num_layers=num_layers,
  281. bidirectional=bidirectional,
  282. union=union,
  283. with_fc=with_fc,
  284. drop=drop,
  285. drop_path=drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path,
  286. **dd,
  287. ))
  288. self.blocks = nn.Sequential(*blocks)
  289. def forward(self, x):
  290. x = self.downsample(x)
  291. x = self.blocks(x)
  292. return x
  293. class Sequencer2d(nn.Module):
  294. def __init__(
  295. self,
  296. num_classes: int = 1000,
  297. img_size: int = 224,
  298. in_chans: int = 3,
  299. global_pool: str = 'avg',
  300. layers: Tuple[int, ...] = (4, 3, 8, 3),
  301. patch_sizes: Tuple[int, ...] = (7, 2, 2, 1),
  302. embed_dims: Tuple[int, ...] = (192, 384, 384, 384),
  303. hidden_sizes: Tuple[int, ...] = (48, 96, 96, 96),
  304. mlp_ratios: Tuple[float, ...] = (3.0, 3.0, 3.0, 3.0),
  305. block_layer: Type[nn.Module] = Sequencer2dBlock,
  306. rnn_layer: Type[nn.Module] = LSTM2d,
  307. mlp_layer: Type[nn.Module] = Mlp,
  308. norm_layer: Type[nn.Module] = partial(nn.LayerNorm, eps=1e-6),
  309. act_layer: Type[nn.Module] = nn.GELU,
  310. num_rnn_layers: int = 1,
  311. bidirectional: bool = True,
  312. union: str = "cat",
  313. with_fc: bool = True,
  314. drop_rate: float = 0.,
  315. drop_path_rate: float = 0.,
  316. nlhb: bool = False,
  317. stem_norm: bool = False,
  318. device=None,
  319. dtype=None,
  320. ):
  321. super().__init__()
  322. dd = {'device': device, 'dtype': dtype}
  323. assert global_pool in ('', 'avg')
  324. self.num_classes = num_classes
  325. self.global_pool = global_pool
  326. self.num_features = self.head_hidden_size = embed_dims[-1] # for consistency with other models
  327. self.feature_dim = -1 # channel dim index for feature outputs (rank 4, NHWC)
  328. self.output_fmt = 'NHWC'
  329. self.feature_info = []
  330. self.stem = PatchEmbed(
  331. img_size=None,
  332. patch_size=patch_sizes[0],
  333. in_chans=in_chans,
  334. embed_dim=embed_dims[0],
  335. norm_layer=norm_layer if stem_norm else None,
  336. flatten=False,
  337. output_fmt='NHWC',
  338. **dd,
  339. )
  340. assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
  341. reductions = list(accumulate(patch_sizes, lambda x, y: x * y))
  342. stages = []
  343. prev_dim = embed_dims[0]
  344. for i, _ in enumerate(embed_dims):
  345. stages += [Sequencer2dStage(
  346. prev_dim,
  347. embed_dims[i],
  348. depth=layers[i],
  349. downsample=i > 0,
  350. patch_size=patch_sizes[i],
  351. hidden_size=hidden_sizes[i],
  352. mlp_ratio=mlp_ratios[i],
  353. block_layer=block_layer,
  354. rnn_layer=rnn_layer,
  355. mlp_layer=mlp_layer,
  356. norm_layer=norm_layer,
  357. act_layer=act_layer,
  358. num_layers=num_rnn_layers,
  359. bidirectional=bidirectional,
  360. union=union,
  361. with_fc=with_fc,
  362. drop=drop_rate,
  363. drop_path=drop_path_rate,
  364. **dd,
  365. )]
  366. prev_dim = embed_dims[i]
  367. self.feature_info += [dict(num_chs=prev_dim, reduction=reductions[i], module=f'stages.{i}')]
  368. self.stages = nn.Sequential(*stages)
  369. self.norm = norm_layer(embed_dims[-1], **dd)
  370. self.head = ClassifierHead(
  371. self.num_features,
  372. num_classes,
  373. pool_type=global_pool,
  374. drop_rate=drop_rate,
  375. input_fmt=self.output_fmt,
  376. **dd,
  377. )
  378. self.init_weights(nlhb=nlhb)
  379. def init_weights(self, nlhb=False):
  380. head_bias = -math.log(self.num_classes) if nlhb else 0.
  381. named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
  382. @torch.jit.ignore
  383. def group_matcher(self, coarse=False):
  384. return dict(
  385. stem=r'^stem',
  386. blocks=[
  387. (r'^stages\.(\d+)', None),
  388. (r'^norm', (99999,))
  389. ] if coarse else [
  390. (r'^stages\.(\d+)\.blocks\.(\d+)', None),
  391. (r'^stages\.(\d+)\.downsample', (0,)),
  392. (r'^norm', (99999,))
  393. ]
  394. )
  395. @torch.jit.ignore
  396. def set_grad_checkpointing(self, enable=True):
  397. assert not enable, 'gradient checkpointing not supported'
  398. @torch.jit.ignore
  399. def get_classifier(self) -> nn.Module:
  400. return self.head
  401. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  402. self.num_classes = num_classes
  403. self.head.reset(num_classes, pool_type=global_pool)
  404. def forward_features(self, x):
  405. x = self.stem(x)
  406. x = self.stages(x)
  407. x = self.norm(x)
  408. return x
  409. def forward_head(self, x, pre_logits: bool = False):
  410. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  411. def forward(self, x):
  412. x = self.forward_features(x)
  413. x = self.forward_head(x)
  414. return x
  415. def checkpoint_filter_fn(state_dict, model):
  416. """ Remap original checkpoints -> timm """
  417. if 'stages.0.blocks.0.norm1.weight' in state_dict:
  418. return state_dict # already translated checkpoint
  419. if 'model' in state_dict:
  420. state_dict = state_dict['model']
  421. import re
  422. out_dict = {}
  423. for k, v in state_dict.items():
  424. k = re.sub(r'blocks.([0-9]+).([0-9]+).down', lambda x: f'stages.{int(x.group(1)) + 1}.downsample.down', k)
  425. k = re.sub(r'blocks.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
  426. k = k.replace('head.', 'head.fc.')
  427. out_dict[k] = v
  428. return out_dict
  429. def _create_sequencer2d(variant, pretrained=False, **kwargs):
  430. default_out_indices = tuple(range(3))
  431. out_indices = kwargs.pop('out_indices', default_out_indices)
  432. model = build_model_with_cfg(
  433. Sequencer2d,
  434. variant,
  435. pretrained,
  436. pretrained_filter_fn=checkpoint_filter_fn,
  437. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  438. **kwargs,
  439. )
  440. return model
  441. def _cfg(url='', **kwargs):
  442. return {
  443. 'url': url,
  444. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
  445. 'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True,
  446. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  447. 'first_conv': 'stem.proj', 'classifier': 'head.fc',
  448. 'license': 'apache-2.0',
  449. **kwargs
  450. }
  451. default_cfgs = generate_default_cfgs({
  452. 'sequencer2d_s.in1k': _cfg(hf_hub_id='timm/'),
  453. 'sequencer2d_m.in1k': _cfg(hf_hub_id='timm/'),
  454. 'sequencer2d_l.in1k': _cfg(hf_hub_id='timm/'),
  455. })
  456. @register_model
  457. def sequencer2d_s(pretrained=False, **kwargs) -> Sequencer2d:
  458. model_args = dict(
  459. layers=[4, 3, 8, 3],
  460. patch_sizes=[7, 2, 1, 1],
  461. embed_dims=[192, 384, 384, 384],
  462. hidden_sizes=[48, 96, 96, 96],
  463. mlp_ratios=[3.0, 3.0, 3.0, 3.0],
  464. rnn_layer=LSTM2d,
  465. bidirectional=True,
  466. union="cat",
  467. with_fc=True,
  468. )
  469. model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **dict(model_args, **kwargs))
  470. return model
  471. @register_model
  472. def sequencer2d_m(pretrained=False, **kwargs) -> Sequencer2d:
  473. model_args = dict(
  474. layers=[4, 3, 14, 3],
  475. patch_sizes=[7, 2, 1, 1],
  476. embed_dims=[192, 384, 384, 384],
  477. hidden_sizes=[48, 96, 96, 96],
  478. mlp_ratios=[3.0, 3.0, 3.0, 3.0],
  479. rnn_layer=LSTM2d,
  480. bidirectional=True,
  481. union="cat",
  482. with_fc=True,
  483. **kwargs)
  484. model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **dict(model_args, **kwargs))
  485. return model
  486. @register_model
  487. def sequencer2d_l(pretrained=False, **kwargs) -> Sequencer2d:
  488. model_args = dict(
  489. layers=[8, 8, 16, 4],
  490. patch_sizes=[7, 2, 1, 1],
  491. embed_dims=[192, 384, 384, 384],
  492. hidden_sizes=[48, 96, 96, 96],
  493. mlp_ratios=[3.0, 3.0, 3.0, 3.0],
  494. rnn_layer=LSTM2d,
  495. bidirectional=True,
  496. union="cat",
  497. with_fc=True,
  498. **kwargs)
  499. model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **dict(model_args, **kwargs))
  500. return model