vgg.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. """VGG
  2. Adapted from https://github.com/pytorch/vision 'vgg.py' (BSD-3-Clause) with a few changes for
  3. timm functionality.
  4. Copyright 2021 Ross Wightman
  5. """
  6. from typing import Any, Dict, List, Optional, Type, Union, cast
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
  11. from timm.layers import ClassifierHead
  12. from ._builder import build_model_with_cfg
  13. from ._features_fx import register_notrace_module
  14. from ._registry import register_model, generate_default_cfgs
  15. __all__ = ['VGG']
  16. cfgs: Dict[str, List[Union[str, int]]] = {
  17. 'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  18. 'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  19. 'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  20. 'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
  21. }
  22. @register_notrace_module # reason: FX can't symbolically trace control flow in forward method
  23. class ConvMlp(nn.Module):
  24. """Convolutional MLP block for VGG head.
  25. Replaces traditional Linear layers with Conv2d layers in the classifier.
  26. """
  27. def __init__(
  28. self,
  29. in_features: int = 512,
  30. out_features: int = 4096,
  31. kernel_size: int = 7,
  32. mlp_ratio: float = 1.0,
  33. drop_rate: float = 0.2,
  34. act_layer: Type[nn.Module] = nn.ReLU,
  35. conv_layer: Type[nn.Module] = nn.Conv2d,
  36. device=None,
  37. dtype=None,
  38. ) -> None:
  39. """Initialize ConvMlp.
  40. Args:
  41. in_features: Number of input features.
  42. out_features: Number of output features.
  43. kernel_size: Kernel size for first conv layer.
  44. mlp_ratio: Ratio for hidden layer size.
  45. drop_rate: Dropout rate.
  46. act_layer: Activation layer type.
  47. conv_layer: Convolution layer type.
  48. """
  49. dd = {'device': device, 'dtype': dtype}
  50. super().__init__()
  51. self.input_kernel_size = kernel_size
  52. mid_features = int(out_features * mlp_ratio)
  53. self.fc1 = conv_layer(in_features, mid_features, kernel_size, bias=True, **dd)
  54. self.act1 = act_layer(True)
  55. self.drop = nn.Dropout(drop_rate)
  56. self.fc2 = conv_layer(mid_features, out_features, 1, bias=True, **dd)
  57. self.act2 = act_layer(True)
  58. def forward(self, x: torch.Tensor) -> torch.Tensor:
  59. """Forward pass.
  60. Args:
  61. x: Input tensor.
  62. Returns:
  63. Output tensor.
  64. """
  65. if x.shape[-2] < self.input_kernel_size or x.shape[-1] < self.input_kernel_size:
  66. # keep the input size >= 7x7
  67. output_size = (max(self.input_kernel_size, x.shape[-2]), max(self.input_kernel_size, x.shape[-1]))
  68. x = F.adaptive_avg_pool2d(x, output_size)
  69. x = self.fc1(x)
  70. x = self.act1(x)
  71. x = self.drop(x)
  72. x = self.fc2(x)
  73. x = self.act2(x)
  74. return x
  75. class VGG(nn.Module):
  76. """VGG model architecture.
  77. Based on `Very Deep Convolutional Networks for Large-Scale Image Recognition`
  78. - https://arxiv.org/abs/1409.1556
  79. """
  80. def __init__(
  81. self,
  82. cfg: List[Any],
  83. num_classes: int = 1000,
  84. in_chans: int = 3,
  85. output_stride: int = 32,
  86. mlp_ratio: float = 1.0,
  87. act_layer: Type[nn.Module] = nn.ReLU,
  88. conv_layer: Type[nn.Module] = nn.Conv2d,
  89. norm_layer: Optional[Type[nn.Module]] = None,
  90. global_pool: str = 'avg',
  91. drop_rate: float = 0.,
  92. device=None,
  93. dtype=None,
  94. ) -> None:
  95. """Initialize VGG model.
  96. Args:
  97. cfg: Configuration list defining network architecture.
  98. num_classes: Number of classes for classification.
  99. in_chans: Number of input channels.
  100. output_stride: Output stride of network.
  101. mlp_ratio: Ratio for MLP hidden layer size.
  102. act_layer: Activation layer type.
  103. conv_layer: Convolution layer type.
  104. norm_layer: Normalization layer type.
  105. global_pool: Global pooling type.
  106. drop_rate: Dropout rate.
  107. """
  108. super().__init__()
  109. dd = {'device': device, 'dtype': dtype}
  110. assert output_stride == 32
  111. self.num_classes = num_classes
  112. self.drop_rate = drop_rate
  113. self.grad_checkpointing = False
  114. self.use_norm = norm_layer is not None
  115. self.feature_info = []
  116. prev_chs = in_chans
  117. net_stride = 1
  118. pool_layer = nn.MaxPool2d
  119. layers: List[nn.Module] = []
  120. for v in cfg:
  121. last_idx = len(layers) - 1
  122. if v == 'M':
  123. self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{last_idx}'))
  124. layers += [pool_layer(kernel_size=2, stride=2)]
  125. net_stride *= 2
  126. else:
  127. v = cast(int, v)
  128. conv2d = conv_layer(prev_chs, v, kernel_size=3, padding=1, **dd)
  129. if norm_layer is not None:
  130. layers += [conv2d, norm_layer(v, **dd), act_layer(inplace=True)]
  131. else:
  132. layers += [conv2d, act_layer(inplace=True)]
  133. prev_chs = v
  134. self.features = nn.Sequential(*layers)
  135. self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}'))
  136. self.num_features = prev_chs
  137. self.head_hidden_size = 4096
  138. self.pre_logits = ConvMlp(
  139. prev_chs,
  140. self.head_hidden_size,
  141. 7,
  142. mlp_ratio=mlp_ratio,
  143. drop_rate=drop_rate,
  144. act_layer=act_layer,
  145. conv_layer=conv_layer,
  146. **dd,
  147. )
  148. self.head = ClassifierHead(
  149. self.head_hidden_size,
  150. num_classes,
  151. pool_type=global_pool,
  152. drop_rate=drop_rate,
  153. **dd,
  154. )
  155. self._initialize_weights()
  156. @torch.jit.ignore
  157. def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
  158. """Group matcher for parameter groups.
  159. Args:
  160. coarse: Whether to use coarse grouping.
  161. Returns:
  162. Dictionary of grouped parameters.
  163. """
  164. # this treats BN layers as separate groups for bn variants, a lot of effort to fix that
  165. return dict(stem=r'^features\.0', blocks=r'^features\.(\d+)')
  166. @torch.jit.ignore
  167. def set_grad_checkpointing(self, enable: bool = True) -> None:
  168. """Enable or disable gradient checkpointing.
  169. Args:
  170. enable: Whether to enable gradient checkpointing.
  171. """
  172. assert not enable, 'gradient checkpointing not supported'
  173. @torch.jit.ignore
  174. def get_classifier(self) -> nn.Module:
  175. """Get the classifier module.
  176. Returns:
  177. Classifier module.
  178. """
  179. return self.head.fc
  180. def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
  181. """Reset the classifier.
  182. Args:
  183. num_classes: Number of classes for new classifier.
  184. global_pool: Global pooling type.
  185. """
  186. self.num_classes = num_classes
  187. self.head.reset(num_classes, global_pool)
  188. def forward_features(self, x: torch.Tensor) -> torch.Tensor:
  189. """Forward pass through feature extraction layers.
  190. Args:
  191. x: Input tensor.
  192. Returns:
  193. Feature tensor.
  194. """
  195. x = self.features(x)
  196. return x
  197. def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
  198. """Forward pass through head.
  199. Args:
  200. x: Input features.
  201. pre_logits: Return features before final linear layer.
  202. Returns:
  203. Classification logits or features.
  204. """
  205. x = self.pre_logits(x)
  206. return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
  207. def forward(self, x: torch.Tensor) -> torch.Tensor:
  208. """Forward pass.
  209. Args:
  210. x: Input tensor.
  211. Returns:
  212. Output logits.
  213. """
  214. x = self.forward_features(x)
  215. x = self.forward_head(x)
  216. return x
  217. def _initialize_weights(self) -> None:
  218. """Initialize model weights."""
  219. for m in self.modules():
  220. if isinstance(m, nn.Conv2d):
  221. nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
  222. if m.bias is not None:
  223. nn.init.constant_(m.bias, 0)
  224. elif isinstance(m, nn.BatchNorm2d):
  225. nn.init.constant_(m.weight, 1)
  226. nn.init.constant_(m.bias, 0)
  227. elif isinstance(m, nn.Linear):
  228. nn.init.normal_(m.weight, 0, 0.01)
  229. nn.init.constant_(m.bias, 0)
  230. def _filter_fn(state_dict: dict) -> Dict[str, torch.Tensor]:
  231. """Convert patch embedding weight from manual patchify + linear proj to conv.
  232. Args:
  233. state_dict: State dictionary to filter.
  234. Returns:
  235. Filtered state dictionary.
  236. """
  237. out_dict = {}
  238. for k, v in state_dict.items():
  239. k_r = k
  240. k_r = k_r.replace('classifier.0', 'pre_logits.fc1')
  241. k_r = k_r.replace('classifier.3', 'pre_logits.fc2')
  242. k_r = k_r.replace('classifier.6', 'head.fc')
  243. if 'classifier.0.weight' in k:
  244. v = v.reshape(-1, 512, 7, 7)
  245. if 'classifier.3.weight' in k:
  246. v = v.reshape(-1, 4096, 1, 1)
  247. out_dict[k_r] = v
  248. return out_dict
  249. def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
  250. """Create a VGG model.
  251. Args:
  252. variant: Model variant name.
  253. pretrained: Load pretrained weights.
  254. **kwargs: Additional model arguments.
  255. Returns:
  256. VGG model instance.
  257. """
  258. cfg = variant.split('_')[0]
  259. # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5]
  260. out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5))
  261. model = build_model_with_cfg(
  262. VGG,
  263. variant,
  264. pretrained,
  265. model_cfg=cfgs[cfg],
  266. feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
  267. pretrained_filter_fn=_filter_fn,
  268. **kwargs,
  269. )
  270. return model
  271. def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
  272. """Create default configuration dictionary.
  273. Args:
  274. url: Model weight URL.
  275. **kwargs: Additional configuration options.
  276. Returns:
  277. Configuration dictionary.
  278. """
  279. return {
  280. 'url': url,
  281. 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
  282. 'crop_pct': 0.875, 'interpolation': 'bilinear',
  283. 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
  284. 'first_conv': 'features.0', 'classifier': 'head.fc',
  285. 'license': 'bsd-3-clause',
  286. **kwargs
  287. }
  288. default_cfgs = generate_default_cfgs({
  289. 'vgg11.tv_in1k': _cfg(hf_hub_id='timm/'),
  290. 'vgg13.tv_in1k': _cfg(hf_hub_id='timm/'),
  291. 'vgg16.tv_in1k': _cfg(hf_hub_id='timm/'),
  292. 'vgg19.tv_in1k': _cfg(hf_hub_id='timm/'),
  293. 'vgg11_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
  294. 'vgg13_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
  295. 'vgg16_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
  296. 'vgg19_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
  297. })
  298. @register_model
  299. def vgg11(pretrained: bool = False, **kwargs: Any) -> VGG:
  300. r"""VGG 11-layer model (configuration "A") from
  301. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  302. """
  303. model_args = dict(**kwargs)
  304. return _create_vgg('vgg11', pretrained=pretrained, **model_args)
  305. @register_model
  306. def vgg11_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
  307. r"""VGG 11-layer model (configuration "A") with batch normalization
  308. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  309. """
  310. model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
  311. return _create_vgg('vgg11_bn', pretrained=pretrained, **model_args)
  312. @register_model
  313. def vgg13(pretrained: bool = False, **kwargs: Any) -> VGG:
  314. r"""VGG 13-layer model (configuration "B")
  315. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  316. """
  317. model_args = dict(**kwargs)
  318. return _create_vgg('vgg13', pretrained=pretrained, **model_args)
  319. @register_model
  320. def vgg13_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
  321. r"""VGG 13-layer model (configuration "B") with batch normalization
  322. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  323. """
  324. model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
  325. return _create_vgg('vgg13_bn', pretrained=pretrained, **model_args)
  326. @register_model
  327. def vgg16(pretrained: bool = False, **kwargs: Any) -> VGG:
  328. r"""VGG 16-layer model (configuration "D")
  329. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  330. """
  331. model_args = dict(**kwargs)
  332. return _create_vgg('vgg16', pretrained=pretrained, **model_args)
  333. @register_model
  334. def vgg16_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
  335. r"""VGG 16-layer model (configuration "D") with batch normalization
  336. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  337. """
  338. model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
  339. return _create_vgg('vgg16_bn', pretrained=pretrained, **model_args)
  340. @register_model
  341. def vgg19(pretrained: bool = False, **kwargs: Any) -> VGG:
  342. r"""VGG 19-layer model (configuration "E")
  343. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  344. """
  345. model_args = dict(**kwargs)
  346. return _create_vgg('vgg19', pretrained=pretrained, **model_args)
  347. @register_model
  348. def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
  349. r"""VGG 19-layer model (configuration 'E') with batch normalization
  350. `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`._
  351. """
  352. model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
  353. return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)