instancenorm.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. import torch.nn.functional as F
  4. from torch import Tensor
  5. from .batchnorm import _LazyNormBase, _NormBase
# Names exported by ``from <this module> import *``: the three eager
# instance-norm modules followed by their lazily-initialized counterparts.
# The shared ``_InstanceNorm`` base class is not listed.
__all__ = [
    "InstanceNorm1d",
    "InstanceNorm2d",
    "InstanceNorm3d",
    "LazyInstanceNorm1d",
    "LazyInstanceNorm2d",
    "LazyInstanceNorm3d",
]
  14. class _InstanceNorm(_NormBase):
  15. def __init__(
  16. self,
  17. num_features: int,
  18. eps: float = 1e-5,
  19. momentum: float = 0.1,
  20. affine: bool = False,
  21. track_running_stats: bool = False,
  22. device=None,
  23. dtype=None,
  24. ) -> None:
  25. factory_kwargs = {"device": device, "dtype": dtype}
  26. super().__init__(
  27. num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
  28. )
  29. def _check_input_dim(self, input):
  30. raise NotImplementedError
  31. def _get_no_batch_dim(self):
  32. raise NotImplementedError
  33. def _handle_no_batch_input(self, input):
  34. return self._apply_instance_norm(input.unsqueeze(0)).squeeze(0)
  35. def _apply_instance_norm(self, input):
  36. return F.instance_norm(
  37. input,
  38. self.running_mean,
  39. self.running_var,
  40. self.weight,
  41. self.bias,
  42. self.training or not self.track_running_stats,
  43. self.momentum if self.momentum is not None else 0.0,
  44. self.eps,
  45. )
  46. def _load_from_state_dict(
  47. self,
  48. state_dict,
  49. prefix,
  50. local_metadata,
  51. strict,
  52. missing_keys,
  53. unexpected_keys,
  54. error_msgs,
  55. ) -> None:
  56. version = local_metadata.get("version", None)
  57. # at version 1: removed running_mean and running_var when
  58. # track_running_stats=False (default)
  59. if version is None and not self.track_running_stats:
  60. running_stats_keys = []
  61. for name in ("running_mean", "running_var"):
  62. key = prefix + name
  63. if key in state_dict:
  64. running_stats_keys.append(key)
  65. if len(running_stats_keys) > 0:
  66. error_msgs.append(
  67. "Unexpected running stats buffer(s) {names} for {klass} "
  68. "with track_running_stats=False. If state_dict is a "
  69. "checkpoint saved before 0.4.0, this may be expected "
  70. "because {klass} does not track running stats by default "
  71. "since 0.4.0. Please remove these keys from state_dict. If "
  72. "the running stats are actually needed, instead set "
  73. "track_running_stats=True in {klass} to enable them. See "
  74. "the documentation of {klass} for details.".format(
  75. names=" and ".join(f'"{k}"' for k in running_stats_keys),
  76. klass=self.__class__.__name__,
  77. )
  78. )
  79. for key in running_stats_keys:
  80. state_dict.pop(key)
  81. super()._load_from_state_dict(
  82. state_dict,
  83. prefix,
  84. local_metadata,
  85. strict,
  86. missing_keys,
  87. unexpected_keys,
  88. error_msgs,
  89. )
  90. def forward(self, input: Tensor) -> Tensor:
  91. self._check_input_dim(input)
  92. feature_dim = input.dim() - self._get_no_batch_dim()
  93. if input.size(feature_dim) != self.num_features:
  94. if self.affine:
  95. raise ValueError(
  96. f"expected input's size at dim={feature_dim} to match num_features"
  97. f" ({self.num_features}), but got: {input.size(feature_dim)}."
  98. )
  99. else:
  100. warnings.warn(
  101. f"input's size at dim={feature_dim} does not match num_features. "
  102. "You can silence this warning by not passing in num_features, "
  103. "which is not used because affine=False"
  104. )
  105. if input.dim() == self._get_no_batch_dim():
  106. return self._handle_no_batch_input(input)
  107. return self._apply_instance_norm(input)
  108. class InstanceNorm1d(_InstanceNorm):
  109. r"""Applies Instance Normalization.
  110. This operation applies Instance Normalization
  111. over a 2D (unbatched) or 3D (batched) input as described in the paper
  112. `Instance Normalization: The Missing Ingredient for Fast Stylization
  113. <https://arxiv.org/abs/1607.08022>`__.
  114. .. math::
  115. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  116. The mean and standard-deviation are calculated per-dimension separately
  117. for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  118. of size `C` (where `C` is the number of features or channels of the input) if :attr:`affine` is ``True``.
  119. The variance is calculated via the biased estimator, equivalent to
  120. `torch.var(input, unbiased=False)`.
  121. By default, this layer uses instance statistics computed from input data in
  122. both training and evaluation modes.
  123. If :attr:`track_running_stats` is set to ``True``, during training this
  124. layer keeps running estimates of its computed mean and variance, which are
  125. then used for normalization during evaluation. The running estimates are
  126. kept with a default :attr:`momentum` of 0.1.
  127. .. note::
  128. This :attr:`momentum` argument is different from one used in optimizer
  129. classes and the conventional notion of momentum. Mathematically, the
  130. update rule for running statistics here is
  131. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  132. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  133. new observed value.
  134. .. note::
  135. :class:`InstanceNorm1d` and :class:`LayerNorm` are very similar, but
  136. have some subtle differences. :class:`InstanceNorm1d` is applied
  137. on each channel of channeled data like multidimensional time series, but
  138. :class:`LayerNorm` is usually applied on entire sample and often in NLP
  139. tasks. Additionally, :class:`LayerNorm` applies elementwise affine
  140. transform, while :class:`InstanceNorm1d` usually don't apply affine
  141. transform.
  142. Args:
  143. num_features: number of features or channels :math:`C` of the input
  144. eps: a value added to the denominator for numerical stability. Default: 1e-5
  145. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  146. affine: a boolean value that when set to ``True``, this module has
  147. learnable affine parameters, initialized the same way as done for batch normalization.
  148. Default: ``False``.
  149. track_running_stats: a boolean value that when set to ``True``, this
  150. module tracks the running mean and variance, and when set to ``False``,
  151. this module does not track such statistics and always uses batch
  152. statistics in both training and eval modes. Default: ``False``
  153. Shape:
  154. - Input: :math:`(N, C, L)` or :math:`(C, L)`
  155. - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)
  156. Examples::
  157. >>> # Without Learnable Parameters
  158. >>> m = nn.InstanceNorm1d(100)
  159. >>> # With Learnable Parameters
  160. >>> m = nn.InstanceNorm1d(100, affine=True)
  161. >>> input = torch.randn(20, 100, 40)
  162. >>> output = m(input)
  163. """
  164. def _get_no_batch_dim(self) -> int:
  165. return 2
  166. def _check_input_dim(self, input) -> None:
  167. if input.dim() not in (2, 3):
  168. raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
  169. class LazyInstanceNorm1d(_LazyNormBase, _InstanceNorm):
  170. r"""A :class:`torch.nn.InstanceNorm1d` module with lazy initialization of the ``num_features`` argument.
  171. The ``num_features`` argument of the :class:`InstanceNorm1d` is inferred from the ``input.size(1)``.
  172. The attributes that will be lazily initialized are `weight`, `bias`, `running_mean` and `running_var`.
  173. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  174. on lazy modules and their limitations.
  175. Args:
  176. num_features: :math:`C` from an expected input of size
  177. :math:`(N, C, L)` or :math:`(C, L)`
  178. eps: a value added to the denominator for numerical stability. Default: 1e-5
  179. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  180. affine: a boolean value that when set to ``True``, this module has
  181. learnable affine parameters, initialized the same way as done for batch normalization.
  182. Default: ``False``.
  183. track_running_stats: a boolean value that when set to ``True``, this
  184. module tracks the running mean and variance, and when set to ``False``,
  185. this module does not track such statistics and always uses batch
  186. statistics in both training and eval modes. Default: ``False``
  187. Shape:
  188. - Input: :math:`(N, C, L)` or :math:`(C, L)`
  189. - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input)
  190. """
  191. cls_to_become = InstanceNorm1d # type: ignore[assignment]
  192. def _get_no_batch_dim(self) -> int:
  193. return 2
  194. def _check_input_dim(self, input) -> None:
  195. if input.dim() not in (2, 3):
  196. raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
  197. class InstanceNorm2d(_InstanceNorm):
  198. r"""Applies Instance Normalization.
  199. This operation applies Instance Normalization
  200. over a 4D input (a mini-batch of 2D inputs
  201. with additional channel dimension) as described in the paper
  202. `Instance Normalization: The Missing Ingredient for Fast Stylization
  203. <https://arxiv.org/abs/1607.08022>`__.
  204. .. math::
  205. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  206. The mean and standard-deviation are calculated per-dimension separately
  207. for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  208. of size `C` (where `C` is the input size) if :attr:`affine` is ``True``.
  209. The standard-deviation is calculated via the biased estimator, equivalent to
  210. `torch.var(input, unbiased=False)`.
  211. By default, this layer uses instance statistics computed from input data in
  212. both training and evaluation modes.
  213. If :attr:`track_running_stats` is set to ``True``, during training this
  214. layer keeps running estimates of its computed mean and variance, which are
  215. then used for normalization during evaluation. The running estimates are
  216. kept with a default :attr:`momentum` of 0.1.
  217. .. note::
  218. This :attr:`momentum` argument is different from one used in optimizer
  219. classes and the conventional notion of momentum. Mathematically, the
  220. update rule for running statistics here is
  221. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  222. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  223. new observed value.
  224. .. note::
  225. :class:`InstanceNorm2d` and :class:`LayerNorm` are very similar, but
  226. have some subtle differences. :class:`InstanceNorm2d` is applied
  227. on each channel of channeled data like RGB images, but
  228. :class:`LayerNorm` is usually applied on entire sample and often in NLP
  229. tasks. Additionally, :class:`LayerNorm` applies elementwise affine
  230. transform, while :class:`InstanceNorm2d` usually don't apply affine
  231. transform.
  232. Args:
  233. num_features: :math:`C` from an expected input of size
  234. :math:`(N, C, H, W)` or :math:`(C, H, W)`
  235. eps: a value added to the denominator for numerical stability. Default: 1e-5
  236. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  237. affine: a boolean value that when set to ``True``, this module has
  238. learnable affine parameters, initialized the same way as done for batch normalization.
  239. Default: ``False``.
  240. track_running_stats: a boolean value that when set to ``True``, this
  241. module tracks the running mean and variance, and when set to ``False``,
  242. this module does not track such statistics and always uses batch
  243. statistics in both training and eval modes. Default: ``False``
  244. Shape:
  245. - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
  246. - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
  247. Examples::
  248. >>> # Without Learnable Parameters
  249. >>> m = nn.InstanceNorm2d(100)
  250. >>> # With Learnable Parameters
  251. >>> m = nn.InstanceNorm2d(100, affine=True)
  252. >>> input = torch.randn(20, 100, 35, 45)
  253. >>> output = m(input)
  254. """
  255. def _get_no_batch_dim(self) -> int:
  256. return 3
  257. def _check_input_dim(self, input) -> None:
  258. if input.dim() not in (3, 4):
  259. raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")
  260. class LazyInstanceNorm2d(_LazyNormBase, _InstanceNorm):
  261. r"""A :class:`torch.nn.InstanceNorm2d` module with lazy initialization of the ``num_features`` argument.
  262. The ``num_features`` argument of the :class:`InstanceNorm2d` is inferred from the ``input.size(1)``.
  263. The attributes that will be lazily initialized are `weight`, `bias`,
  264. `running_mean` and `running_var`.
  265. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  266. on lazy modules and their limitations.
  267. Args:
  268. num_features: :math:`C` from an expected input of size
  269. :math:`(N, C, H, W)` or :math:`(C, H, W)`
  270. eps: a value added to the denominator for numerical stability. Default: 1e-5
  271. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  272. affine: a boolean value that when set to ``True``, this module has
  273. learnable affine parameters, initialized the same way as done for batch normalization.
  274. Default: ``False``.
  275. track_running_stats: a boolean value that when set to ``True``, this
  276. module tracks the running mean and variance, and when set to ``False``,
  277. this module does not track such statistics and always uses batch
  278. statistics in both training and eval modes. Default: ``False``
  279. Shape:
  280. - Input: :math:`(N, C, H, W)` or :math:`(C, H, W)`
  281. - Output: :math:`(N, C, H, W)` or :math:`(C, H, W)` (same shape as input)
  282. """
  283. cls_to_become = InstanceNorm2d # type: ignore[assignment]
  284. def _get_no_batch_dim(self) -> int:
  285. return 3
  286. def _check_input_dim(self, input) -> None:
  287. if input.dim() not in (3, 4):
  288. raise ValueError(f"expected 3D or 4D input (got {input.dim()}D input)")
  289. class InstanceNorm3d(_InstanceNorm):
  290. r"""Applies Instance Normalization.
  291. This operation applies Instance Normalization
  292. over a 5D input (a mini-batch of 3D inputs with additional channel dimension) as described in the paper
  293. `Instance Normalization: The Missing Ingredient for Fast Stylization
  294. <https://arxiv.org/abs/1607.08022>`__.
  295. .. math::
  296. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  297. The mean and standard-deviation are calculated per-dimension separately
  298. for each object in a mini-batch. :math:`\gamma` and :math:`\beta` are learnable parameter vectors
  299. of size C (where C is the input size) if :attr:`affine` is ``True``.
  300. The standard-deviation is calculated via the biased estimator, equivalent to
  301. `torch.var(input, unbiased=False)`.
  302. By default, this layer uses instance statistics computed from input data in
  303. both training and evaluation modes.
  304. If :attr:`track_running_stats` is set to ``True``, during training this
  305. layer keeps running estimates of its computed mean and variance, which are
  306. then used for normalization during evaluation. The running estimates are
  307. kept with a default :attr:`momentum` of 0.1.
  308. .. note::
  309. This :attr:`momentum` argument is different from one used in optimizer
  310. classes and the conventional notion of momentum. Mathematically, the
  311. update rule for running statistics here is
  312. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  313. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  314. new observed value.
  315. .. note::
  316. :class:`InstanceNorm3d` and :class:`LayerNorm` are very similar, but
  317. have some subtle differences. :class:`InstanceNorm3d` is applied
  318. on each channel of channeled data like 3D models with RGB color, but
  319. :class:`LayerNorm` is usually applied on entire sample and often in NLP
  320. tasks. Additionally, :class:`LayerNorm` applies elementwise affine
  321. transform, while :class:`InstanceNorm3d` usually don't apply affine
  322. transform.
  323. Args:
  324. num_features: :math:`C` from an expected input of size
  325. :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
  326. eps: a value added to the denominator for numerical stability. Default: 1e-5
  327. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  328. affine: a boolean value that when set to ``True``, this module has
  329. learnable affine parameters, initialized the same way as done for batch normalization.
  330. Default: ``False``.
  331. track_running_stats: a boolean value that when set to ``True``, this
  332. module tracks the running mean and variance, and when set to ``False``,
  333. this module does not track such statistics and always uses batch
  334. statistics in both training and eval modes. Default: ``False``
  335. Shape:
  336. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
  337. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)
  338. Examples::
  339. >>> # Without Learnable Parameters
  340. >>> m = nn.InstanceNorm3d(100)
  341. >>> # With Learnable Parameters
  342. >>> m = nn.InstanceNorm3d(100, affine=True)
  343. >>> input = torch.randn(20, 100, 35, 45, 10)
  344. >>> output = m(input)
  345. """
  346. def _get_no_batch_dim(self) -> int:
  347. return 4
  348. def _check_input_dim(self, input) -> None:
  349. if input.dim() not in (4, 5):
  350. raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")
  351. class LazyInstanceNorm3d(_LazyNormBase, _InstanceNorm):
  352. r"""A :class:`torch.nn.InstanceNorm3d` module with lazy initialization of the ``num_features`` argument.
  353. The ``num_features`` argument of the :class:`InstanceNorm3d` is inferred from the ``input.size(1)``.
  354. The attributes that will be lazily initialized are `weight`, `bias`,
  355. `running_mean` and `running_var`.
  356. Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
  357. on lazy modules and their limitations.
  358. Args:
  359. num_features: :math:`C` from an expected input of size
  360. :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
  361. eps: a value added to the denominator for numerical stability. Default: 1e-5
  362. momentum: the value used for the running_mean and running_var computation. Default: 0.1
  363. affine: a boolean value that when set to ``True``, this module has
  364. learnable affine parameters, initialized the same way as done for batch normalization.
  365. Default: ``False``.
  366. track_running_stats: a boolean value that when set to ``True``, this
  367. module tracks the running mean and variance, and when set to ``False``,
  368. this module does not track such statistics and always uses batch
  369. statistics in both training and eval modes. Default: ``False``
  370. Shape:
  371. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`
  372. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input)
  373. """
  374. cls_to_become = InstanceNorm3d # type: ignore[assignment]
  375. def _get_no_batch_dim(self) -> int:
  376. return 4
  377. def _check_input_dim(self, input) -> None:
  378. if input.dim() not in (4, 5):
  379. raise ValueError(f"expected 4D or 5D input (got {input.dim()}D input)")