rdnet.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. """
  2. RDNet
  3. Copyright (c) 2024-present NAVER Cloud Corp.
  4. Apache-2.0
  5. """
  6. from functools import partial
  7. from typing import List, Optional, Tuple, Union, Callable, Type
  8. import torch
  9. import torch.nn as nn
  10. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  11. from timm.layers import DropPath, calculate_drop_path_rates, NormMlpClassifierHead, ClassifierHead, EffectiveSEModule, \
  12. make_divisible, get_act_layer, get_norm_layer
  13. from ._builder import build_model_with_cfg
  14. from ._features import feature_take_indices
  15. from ._manipulate import named_apply
  16. from ._registry import register_model, generate_default_cfgs
  17. __all__ = ["RDNet"]
  18. class Block(nn.Module):
  19. def __init__(
  20. self,
  21. in_chs: int,
  22. inter_chs: int,
  23. out_chs: int,
  24. norm_layer: Type[nn.Module],
  25. act_layer: Type[nn.Module],
  26. device=None,
  27. dtype=None,
  28. ):
  29. dd = {'device': device, 'dtype': dtype}
  30. super().__init__()
  31. self.layers = nn.Sequential(
  32. nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3, **dd),
  33. norm_layer(in_chs, **dd),
  34. nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0, **dd),
  35. act_layer(),
  36. nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0, **dd),
  37. )
  38. def forward(self, x):
  39. return self.layers(x)
  40. class BlockESE(nn.Module):
  41. def __init__(
  42. self,
  43. in_chs: int,
  44. inter_chs: int,
  45. out_chs: int,
  46. norm_layer: Type[nn.Module],
  47. act_layer: Type[nn.Module],
  48. device=None,
  49. dtype=None,
  50. ):
  51. dd = {'device': device, 'dtype': dtype}
  52. super().__init__()
  53. self.layers = nn.Sequential(
  54. nn.Conv2d(in_chs, in_chs, groups=in_chs, kernel_size=7, stride=1, padding=3, **dd),
  55. norm_layer(in_chs, **dd),
  56. nn.Conv2d(in_chs, inter_chs, kernel_size=1, stride=1, padding=0, **dd),
  57. act_layer(),
  58. nn.Conv2d(inter_chs, out_chs, kernel_size=1, stride=1, padding=0, **dd),
  59. EffectiveSEModule(out_chs, **dd),
  60. )
  61. def forward(self, x):
  62. return self.layers(x)
  63. def _get_block_type(block: str):
  64. block = block.lower().strip()
  65. if block == "block":
  66. return Block
  67. elif block == "blockese":
  68. return BlockESE
  69. else:
  70. assert False, f"Unknown block type ({block})."
  71. class DenseBlock(nn.Module):
  72. def __init__(
  73. self,
  74. num_input_features: int = 64,
  75. growth_rate: int = 64,
  76. bottleneck_width_ratio: float = 4.0,
  77. drop_path_rate: float = 0.0,
  78. drop_rate: float = 0.0,
  79. rand_gather_step_prob: float = 0.0,
  80. block_idx: int = 0,
  81. block_type: str = "Block",
  82. ls_init_value: float = 1e-6,
  83. norm_layer: Type[nn.Module] = nn.LayerNorm,
  84. act_layer: Type[nn.Module] = nn.GELU,
  85. device=None,
  86. dtype=None,
  87. ):
  88. dd = {'device': device, 'dtype': dtype}
  89. super().__init__()
  90. self.drop_rate = drop_rate
  91. self.drop_path_rate = drop_path_rate
  92. self.rand_gather_step_prob = rand_gather_step_prob
  93. self.block_idx = block_idx
  94. self.growth_rate = growth_rate
  95. self.gamma = nn.Parameter(ls_init_value * torch.ones(growth_rate, **dd)) if ls_init_value > 0 else None
  96. growth_rate = int(growth_rate)
  97. inter_chs = int(num_input_features * bottleneck_width_ratio / 8) * 8
  98. self.drop_path = DropPath(drop_path_rate)
  99. self.layers = _get_block_type(block_type)(
  100. in_chs=num_input_features,
  101. inter_chs=inter_chs,
  102. out_chs=growth_rate,
  103. norm_layer=norm_layer,
  104. act_layer=act_layer,
  105. **dd,
  106. )
  107. def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
  108. x = torch.cat(x, 1)
  109. x = self.layers(x)
  110. if self.gamma is not None:
  111. x = x.mul(self.gamma.reshape(1, -1, 1, 1))
  112. x = self.drop_path(x)
  113. return x
  114. class DenseStage(nn.Sequential):
  115. def __init__(
  116. self,
  117. num_block: int,
  118. num_input_features: int,
  119. drop_path_rates: List[float],
  120. growth_rate: int,
  121. device=None,
  122. dtype=None,
  123. **kwargs,
  124. ):
  125. dd = {'device': device, 'dtype': dtype}
  126. super().__init__()
  127. for i in range(num_block):
  128. layer = DenseBlock(
  129. num_input_features=num_input_features,
  130. growth_rate=growth_rate,
  131. drop_path_rate=drop_path_rates[i],
  132. block_idx=i,
  133. **dd,
  134. **kwargs,
  135. )
  136. num_input_features += growth_rate
  137. self.add_module(f"dense_block{i}", layer)
  138. self.num_out_features = num_input_features
  139. def forward(self, init_feature: torch.Tensor) -> torch.Tensor:
  140. features = [init_feature]
  141. for module in self:
  142. new_feature = module(features)
  143. features.append(new_feature)
  144. return torch.cat(features, 1)
  145. class RDNet(nn.Module):
  146. def __init__(
  147. self,
  148. in_chans: int = 3, # timm option [--in-chans]
  149. num_classes: int = 1000, # timm option [--num-classes]
  150. global_pool: str = 'avg', # timm option [--gp]
  151. growth_rates: Union[List[int], Tuple[int]] = (64, 104, 128, 128, 128, 128, 224),
  152. num_blocks_list: Union[List[int], Tuple[int]] = (3, 3, 3, 3, 3, 3, 3),
  153. block_type: Union[List[int], Tuple[int]] = ("Block",) * 2 + ("BlockESE",) * 5,
  154. is_downsample_block: Union[List[bool], Tuple[bool]] = (None, True, True, False, False, False, True),
  155. bottleneck_width_ratio: float = 4.0,
  156. transition_compression_ratio: float = 0.5,
  157. ls_init_value: float = 1e-6,
  158. stem_type: str = 'patch',
  159. patch_size: int = 4,
  160. num_init_features: int = 64,
  161. head_init_scale: float = 1.,
  162. head_norm_first: bool = False,
  163. conv_bias: bool = True,
  164. act_layer: Union[str, Callable] = 'gelu',
  165. norm_layer: str = "layernorm2d",
  166. norm_eps: Optional[float] = None,
  167. drop_rate: float = 0.0, # timm option [--drop: dropout ratio]
  168. drop_path_rate: float = 0.0, # timm option [--drop-path: drop-path ratio]
  169. device=None,
  170. dtype=None,
  171. ):
  172. """
  173. Args:
  174. in_chans: Number of input image channels.
  175. num_classes: Number of classes for classification head.
  176. global_pool: Global pooling type.
  177. growth_rates: Growth rate at each stage.
  178. num_blocks_list: Number of blocks at each stage.
  179. is_downsample_block: Whether to downsample at each stage.
  180. bottleneck_width_ratio: Bottleneck width ratio (similar to mlp expansion ratio).
  181. transition_compression_ratio: Channel compression ratio of transition layers.
  182. ls_init_value: Init value for Layer Scale, disabled if None.
  183. stem_type: Type of stem.
  184. patch_size: Stem patch size for patch stem.
  185. num_init_features: Number of features of stem.
  186. head_init_scale: Init scaling value for classifier weights and biases.
  187. head_norm_first: Apply normalization before global pool + head.
  188. conv_bias: Use bias layers w/ all convolutions.
  189. act_layer: Activation layer type.
  190. norm_layer: Normalization layer type.
  191. norm_eps: Small value to avoid division by zero in normalization.
  192. drop_rate: Head pre-classifier dropout rate.
  193. drop_path_rate: Stochastic depth drop rate.
  194. """
  195. super().__init__()
  196. dd = {'device': device, 'dtype': dtype}
  197. assert len(growth_rates) == len(num_blocks_list) == len(is_downsample_block)
  198. act_layer = get_act_layer(act_layer)
  199. norm_layer = get_norm_layer(norm_layer)
  200. if norm_eps is not None:
  201. norm_layer = partial(norm_layer, eps=norm_eps)
  202. self.num_classes = num_classes
  203. self.drop_rate = drop_rate
  204. # stem
  205. assert stem_type in ('patch', 'overlap', 'overlap_tiered')
  206. if stem_type == 'patch':
  207. # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
  208. self.stem = nn.Sequential(
  209. nn.Conv2d(in_chans, num_init_features, kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd),
  210. norm_layer(num_init_features, **dd),
  211. )
  212. stem_stride = patch_size
  213. else:
  214. mid_chs = make_divisible(num_init_features // 2) if 'tiered' in stem_type else num_init_features
  215. self.stem = nn.Sequential(
  216. nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
  217. nn.Conv2d(mid_chs, num_init_features, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
  218. norm_layer(num_init_features, **dd),
  219. )
  220. stem_stride = 4
  221. # features
  222. self.feature_info = []
  223. self.num_stages = len(growth_rates)
  224. curr_stride = stem_stride
  225. num_features = num_init_features
  226. dp_rates = calculate_drop_path_rates(drop_path_rate, num_blocks_list, stagewise=True)
  227. dense_stages = []
  228. for i in range(self.num_stages):
  229. dense_stage_layers = []
  230. if i != 0:
  231. compressed_num_features = int(num_features * transition_compression_ratio / 8) * 8
  232. k_size = stride = 1
  233. if is_downsample_block[i]:
  234. curr_stride *= 2
  235. k_size = stride = 2
  236. dense_stage_layers.append(norm_layer(num_features, **dd))
  237. dense_stage_layers.append(nn.Conv2d(
  238. num_features,
  239. compressed_num_features,
  240. kernel_size=k_size,
  241. stride=stride,
  242. padding=0,
  243. **dd,
  244. ))
  245. num_features = compressed_num_features
  246. stage = DenseStage(
  247. num_block=num_blocks_list[i],
  248. num_input_features=num_features,
  249. growth_rate=growth_rates[i],
  250. bottleneck_width_ratio=bottleneck_width_ratio,
  251. drop_rate=drop_rate,
  252. drop_path_rates=dp_rates[i],
  253. ls_init_value=ls_init_value,
  254. block_type=block_type[i],
  255. norm_layer=norm_layer,
  256. act_layer=act_layer,
  257. **dd,
  258. )
  259. dense_stage_layers.append(stage)
  260. num_features += num_blocks_list[i] * growth_rates[i]
  261. if i + 1 == self.num_stages or (i + 1 != self.num_stages and is_downsample_block[i + 1]):
  262. self.feature_info += [
  263. dict(
  264. num_chs=num_features,
  265. reduction=curr_stride,
  266. module=f'dense_stages.{i}',
  267. growth_rate=growth_rates[i],
  268. )
  269. ]
  270. dense_stages.append(nn.Sequential(*dense_stage_layers))
  271. self.dense_stages = nn.Sequential(*dense_stages)
  272. self.num_features = self.head_hidden_size = num_features
  273. # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
  274. # otherwise pool -> norm -> fc, the default RDNet ordering (pretrained NV weights)
  275. if head_norm_first:
  276. self.norm_pre = norm_layer(self.num_features, **dd)
  277. self.head = ClassifierHead(
  278. self.num_features,
  279. num_classes,
  280. pool_type=global_pool,
  281. drop_rate=self.drop_rate,
  282. **dd,
  283. )
  284. else:
  285. self.norm_pre = nn.Identity()
  286. self.head = NormMlpClassifierHead(
  287. self.num_features,
  288. num_classes,
  289. pool_type=global_pool,
  290. drop_rate=self.drop_rate,
  291. norm_layer=norm_layer,
  292. **dd,
  293. )
  294. named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
  295. @torch.jit.ignore
  296. def group_matcher(self, coarse=False):
  297. assert not coarse, "coarse grouping is not implemented for RDNet"
  298. return dict(
  299. stem=r'^stem',
  300. blocks=r'^dense_stages\.(\d+)',
  301. )
  302. @torch.jit.ignore
  303. def set_grad_checkpointing(self, enable=True):
  304. for s in self.dense_stages:
  305. s.grad_checkpointing = enable
  306. @torch.jit.ignore
  307. def get_classifier(self) -> nn.Module:
  308. return self.head.fc
  309. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
  310. self.num_classes = num_classes
  311. self.head.reset(num_classes, global_pool)
  312. def forward_intermediates(
  313. self,
  314. x: torch.Tensor,
  315. indices: Optional[Union[int, List[int]]] = None,
  316. norm: bool = False,
  317. stop_early: bool = False,
  318. output_fmt: str = 'NCHW',
  319. intermediates_only: bool = False,
  320. ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
  321. """ Forward features that returns intermediates.
  322. Args:
  323. x: Input image tensor
  324. indices: Take last n blocks if int, all if None, select matching indices if sequence
  325. norm: Apply norm layer to compatible intermediates
  326. stop_early: Stop iterating over blocks when last desired intermediate hit
  327. output_fmt: Shape of intermediate feature outputs
  328. intermediates_only: Only return intermediate features
  329. """
  330. assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
  331. intermediates = []
  332. stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
  333. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  334. take_indices = [stage_ends[i] for i in take_indices]
  335. max_index = stage_ends[max_index]
  336. # forward pass
  337. x = self.stem(x)
  338. last_idx = len(self.dense_stages) - 1
  339. if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
  340. dense_stages = self.dense_stages
  341. else:
  342. dense_stages = self.dense_stages[:max_index + 1]
  343. for feat_idx, stage in enumerate(dense_stages):
  344. x = stage(x)
  345. if feat_idx in take_indices:
  346. if norm and feat_idx == last_idx:
  347. x_inter = self.norm_pre(x) # applying final norm to last intermediate
  348. else:
  349. x_inter = x
  350. intermediates.append(x_inter)
  351. if intermediates_only:
  352. return intermediates
  353. if feat_idx == last_idx:
  354. x = self.norm_pre(x)
  355. return x, intermediates
  356. def prune_intermediate_layers(
  357. self,
  358. indices: Union[int, List[int]] = 1,
  359. prune_norm: bool = False,
  360. prune_head: bool = True,
  361. ):
  362. """ Prune layers not required for specified intermediates.
  363. """
  364. stage_ends = [int(info['module'].split('.')[-1]) for info in self.feature_info]
  365. take_indices, max_index = feature_take_indices(len(stage_ends), indices)
  366. max_index = stage_ends[max_index]
  367. self.dense_stages = self.dense_stages[:max_index + 1] # truncate blocks w/ stem as idx 0
  368. if prune_norm:
  369. self.norm_pre = nn.Identity()
  370. if prune_head:
  371. self.reset_classifier(0, '')
  372. return take_indices
  373. def forward_features(self, x):
  374. x = self.stem(x)
  375. x = self.dense_stages(x)
  376. x = self.norm_pre(x)
  377. return x
  378. def forward_head(self, x, pre_logits: bool = False):
  379. return self.head(x, pre_logits=True) if pre_logits else self.head(x)
  380. def forward(self, x):
  381. x = self.forward_features(x)
  382. x = self.forward_head(x)
  383. return x
  384. def _init_weights(module, name=None, head_init_scale=1.0):
  385. if isinstance(module, nn.Conv2d):
  386. nn.init.kaiming_normal_(module.weight)
  387. elif isinstance(module, nn.BatchNorm2d):
  388. nn.init.constant_(module.weight, 1)
  389. nn.init.constant_(module.bias, 0)
  390. elif isinstance(module, nn.Linear):
  391. nn.init.constant_(module.bias, 0)
  392. if name and 'head.' in name:
  393. module.weight.data.mul_(head_init_scale)
  394. module.bias.data.mul_(head_init_scale)
  395. def checkpoint_filter_fn(state_dict, model):
  396. """ Remap NV checkpoints -> timm """
  397. if 'stem.0.weight' in state_dict:
  398. return state_dict # non-NV checkpoint
  399. if 'model' in state_dict:
  400. state_dict = state_dict['model']
  401. out_dict = {}
  402. for k, v in state_dict.items():
  403. k = k.replace('stem.stem.', 'stem.')
  404. out_dict[k] = v
  405. return out_dict
  406. def _create_rdnet(variant, pretrained=False, **kwargs):
  407. model = build_model_with_cfg(
  408. RDNet, variant, pretrained,
  409. pretrained_filter_fn=checkpoint_filter_fn,
  410. feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
  411. **kwargs)
  412. return model
  413. def _cfg(url='', **kwargs):
  414. return {
  415. "url": url,
  416. "num_classes": 1000, "input_size": (3, 224, 224), "pool_size": (7, 7),
  417. "crop_pct": 0.9, "interpolation": "bicubic",
  418. "mean": IMAGENET_DEFAULT_MEAN, "std": IMAGENET_DEFAULT_STD,
  419. "first_conv": "stem.0", "classifier": "head.fc",
  420. "paper_ids": "arXiv:2403.19588",
  421. "paper_name": "DenseNets Reloaded: Paradigm Shift Beyond ResNets and ViTs",
  422. "origin_url": "https://github.com/naver-ai/rdnet",
  423. "license": "apache-2.0",
  424. **kwargs,
  425. }
  426. default_cfgs = generate_default_cfgs({
  427. 'rdnet_tiny.nv_in1k': _cfg(
  428. hf_hub_id='naver-ai/rdnet_tiny.nv_in1k'),
  429. 'rdnet_small.nv_in1k': _cfg(
  430. hf_hub_id='naver-ai/rdnet_small.nv_in1k'),
  431. 'rdnet_base.nv_in1k': _cfg(
  432. hf_hub_id='naver-ai/rdnet_base.nv_in1k'),
  433. 'rdnet_large.nv_in1k': _cfg(
  434. hf_hub_id='naver-ai/rdnet_large.nv_in1k'),
  435. 'rdnet_large.nv_in1k_ft_in1k_384': _cfg(
  436. hf_hub_id='naver-ai/rdnet_large.nv_in1k_ft_in1k_384',
  437. input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
  438. })
  439. @register_model
  440. def rdnet_tiny(pretrained=False, **kwargs):
  441. n_layer = 7
  442. model_args = {
  443. "num_init_features": 64,
  444. "growth_rates": [64] + [104] + [128] * 4 + [224],
  445. "num_blocks_list": [3] * n_layer,
  446. "is_downsample_block": (None, True, True, False, False, False, True),
  447. "transition_compression_ratio": 0.5,
  448. "block_type": ["Block"] + ["Block"] + ["BlockESE"] * 4 + ["BlockESE"],
  449. }
  450. model = _create_rdnet("rdnet_tiny", pretrained=pretrained, **dict(model_args, **kwargs))
  451. return model
  452. @register_model
  453. def rdnet_small(pretrained=False, **kwargs):
  454. n_layer = 11
  455. model_args = {
  456. "num_init_features": 72,
  457. "growth_rates": [64] + [128] + [128] * (n_layer - 4) + [240] * 2,
  458. "num_blocks_list": [3] * n_layer,
  459. "is_downsample_block": (None, True, True, False, False, False, False, False, False, True, False),
  460. "transition_compression_ratio": 0.5,
  461. "block_type": ["Block"] + ["Block"] + ["BlockESE"] * (n_layer - 4) + ["BlockESE"] * 2,
  462. }
  463. model = _create_rdnet("rdnet_small", pretrained=pretrained, **dict(model_args, **kwargs))
  464. return model
  465. @register_model
  466. def rdnet_base(pretrained=False, **kwargs):
  467. n_layer = 11
  468. model_args = {
  469. "num_init_features": 120,
  470. "growth_rates": [96] + [128] + [168] * (n_layer - 4) + [336] * 2,
  471. "num_blocks_list": [3] * n_layer,
  472. "is_downsample_block": (None, True, True, False, False, False, False, False, False, True, False),
  473. "transition_compression_ratio": 0.5,
  474. "block_type": ["Block"] + ["Block"] + ["BlockESE"] * (n_layer - 4) + ["BlockESE"] * 2,
  475. }
  476. model = _create_rdnet("rdnet_base", pretrained=pretrained, **dict(model_args, **kwargs))
  477. return model
  478. @register_model
  479. def rdnet_large(pretrained=False, **kwargs):
  480. n_layer = 12
  481. model_args = {
  482. "num_init_features": 144,
  483. "growth_rates": [128] + [192] + [256] * (n_layer - 4) + [360] * 2,
  484. "num_blocks_list": [3] * n_layer,
  485. "is_downsample_block": (None, True, True, False, False, False, False, False, False, False, True, False),
  486. "transition_compression_ratio": 0.5,
  487. "block_type": ["Block"] + ["Block"] + ["BlockESE"] * (n_layer - 4) + ["BlockESE"] * 2,
  488. }
  489. model = _create_rdnet("rdnet_large", pretrained=pretrained, **dict(model_args, **kwargs))
  490. return model