batchnorm.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889
  1. # mypy: allow-untyped-defs
  2. from typing import Any, Optional
  3. import torch
  4. from torch import Tensor
  5. from torch.nn import functional as F, init
  6. from torch.nn.parameter import Parameter, UninitializedBuffer, UninitializedParameter
  7. from ._functions import SyncBatchNorm as sync_batch_norm
  8. from .lazy import LazyModuleMixin
  9. from .module import Module
  10. __all__ = [
  11. "BatchNorm1d",
  12. "LazyBatchNorm1d",
  13. "BatchNorm2d",
  14. "LazyBatchNorm2d",
  15. "BatchNorm3d",
  16. "LazyBatchNorm3d",
  17. "SyncBatchNorm",
  18. ]
  19. class _NormBase(Module):
  20. """Common base of _InstanceNorm and _BatchNorm."""
  21. _version = 2
  22. __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
  23. num_features: int
  24. eps: float
  25. momentum: Optional[float]
  26. affine: bool
  27. track_running_stats: bool
  28. # WARNING: weight and bias purposely not defined here.
  29. # See https://github.com/pytorch/pytorch/issues/39670
  30. def __init__(
  31. self,
  32. num_features: int,
  33. eps: float = 1e-5,
  34. momentum: Optional[float] = 0.1,
  35. affine: bool = True,
  36. track_running_stats: bool = True,
  37. device=None,
  38. dtype=None,
  39. ) -> None:
  40. factory_kwargs = {"device": device, "dtype": dtype}
  41. super().__init__()
  42. self.num_features = num_features
  43. self.eps = eps
  44. self.momentum = momentum
  45. self.affine = affine
  46. self.track_running_stats = track_running_stats
  47. if self.affine:
  48. self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
  49. self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
  50. else:
  51. self.register_parameter("weight", None)
  52. self.register_parameter("bias", None)
  53. if self.track_running_stats:
  54. self.register_buffer(
  55. "running_mean", torch.zeros(num_features, **factory_kwargs)
  56. )
  57. self.register_buffer(
  58. "running_var", torch.ones(num_features, **factory_kwargs)
  59. )
  60. self.running_mean: Optional[Tensor]
  61. self.running_var: Optional[Tensor]
  62. self.register_buffer(
  63. "num_batches_tracked",
  64. torch.tensor(
  65. 0,
  66. dtype=torch.long,
  67. **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
  68. ),
  69. )
  70. self.num_batches_tracked: Optional[Tensor]
  71. else:
  72. self.register_buffer("running_mean", None)
  73. self.register_buffer("running_var", None)
  74. self.register_buffer("num_batches_tracked", None)
  75. self.reset_parameters()
  76. def reset_running_stats(self) -> None:
  77. if self.track_running_stats:
  78. # running_mean/running_var/num_batches... are registered at runtime depending
  79. # if self.track_running_stats is on
  80. self.running_mean.zero_() # type: ignore[union-attr]
  81. self.running_var.fill_(1) # type: ignore[union-attr]
  82. self.num_batches_tracked.zero_() # type: ignore[union-attr,operator]
  83. def reset_parameters(self) -> None:
  84. self.reset_running_stats()
  85. if self.affine:
  86. init.ones_(self.weight)
  87. init.zeros_(self.bias)
  88. def _check_input_dim(self, input):
  89. raise NotImplementedError
  90. def extra_repr(self):
  91. return (
  92. "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
  93. "track_running_stats={track_running_stats}".format(**self.__dict__)
  94. )
  95. def _load_from_state_dict(
  96. self,
  97. state_dict,
  98. prefix,
  99. local_metadata,
  100. strict,
  101. missing_keys,
  102. unexpected_keys,
  103. error_msgs,
  104. ) -> None:
  105. version = local_metadata.get("version", None)
  106. if (version is None or version < 2) and self.track_running_stats:
  107. # at version 2: added num_batches_tracked buffer
  108. # this should have a default value of 0
  109. num_batches_tracked_key = prefix + "num_batches_tracked"
  110. if num_batches_tracked_key not in state_dict:
  111. state_dict[num_batches_tracked_key] = (
  112. self.num_batches_tracked
  113. if self.num_batches_tracked is not None
  114. and self.num_batches_tracked.device != torch.device("meta")
  115. else torch.tensor(0, dtype=torch.long)
  116. )
  117. super()._load_from_state_dict(
  118. state_dict,
  119. prefix,
  120. local_metadata,
  121. strict,
  122. missing_keys,
  123. unexpected_keys,
  124. error_msgs,
  125. )
  126. class _BatchNorm(_NormBase):
  127. def __init__(
  128. self,
  129. num_features: int,
  130. eps: float = 1e-5,
  131. momentum: Optional[float] = 0.1,
  132. affine: bool = True,
  133. track_running_stats: bool = True,
  134. device=None,
  135. dtype=None,
  136. ) -> None:
  137. factory_kwargs = {"device": device, "dtype": dtype}
  138. super().__init__(
  139. num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
  140. )
  141. def forward(self, input: Tensor) -> Tensor:
  142. self._check_input_dim(input)
  143. # exponential_average_factor is set to self.momentum
  144. # (when it is available) only so that it gets updated
  145. # in ONNX graph when this node is exported to ONNX.
  146. if self.momentum is None:
  147. exponential_average_factor = 0.0
  148. else:
  149. exponential_average_factor = self.momentum
  150. if self.training and self.track_running_stats:
  151. # TODO: if statement only here to tell the jit to skip emitting this when it is None
  152. if self.num_batches_tracked is not None: # type: ignore[has-type]
  153. self.num_batches_tracked.add_(1) # type: ignore[has-type]
  154. if self.momentum is None: # use cumulative moving average
  155. exponential_average_factor = 1.0 / float(self.num_batches_tracked)
  156. else: # use exponential moving average
  157. exponential_average_factor = self.momentum
  158. r"""
  159. Decide whether the mini-batch stats should be used for normalization rather than the buffers.
  160. Mini-batch stats are used in training mode, and in eval mode when buffers are None.
  161. """
  162. if self.training:
  163. bn_training = True
  164. else:
  165. bn_training = (self.running_mean is None) and (self.running_var is None)
  166. r"""
  167. Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
  168. passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
  169. used for normalization (i.e. in eval mode when buffers are not None).
  170. """
  171. return F.batch_norm(
  172. input,
  173. # If buffers are not to be tracked, ensure that they won't be updated
  174. (
  175. self.running_mean
  176. if not self.training or self.track_running_stats
  177. else None
  178. ),
  179. self.running_var if not self.training or self.track_running_stats else None,
  180. self.weight,
  181. self.bias,
  182. bn_training,
  183. exponential_average_factor,
  184. self.eps,
  185. )
  186. class _LazyNormBase(LazyModuleMixin, _NormBase):
  187. weight: UninitializedParameter # type: ignore[assignment]
  188. bias: UninitializedParameter # type: ignore[assignment]
  189. def __init__(
  190. self,
  191. eps=1e-5,
  192. momentum=0.1,
  193. affine=True,
  194. track_running_stats=True,
  195. device=None,
  196. dtype=None,
  197. ) -> None:
  198. factory_kwargs = {"device": device, "dtype": dtype}
  199. super().__init__(
  200. # affine and track_running_stats are hardcoded to False to
  201. # avoid creating tensors that will soon be overwritten.
  202. 0,
  203. eps,
  204. momentum,
  205. False,
  206. False,
  207. **factory_kwargs,
  208. )
  209. self.affine = affine
  210. self.track_running_stats = track_running_stats
  211. if self.affine:
  212. self.weight = UninitializedParameter(**factory_kwargs)
  213. self.bias = UninitializedParameter(**factory_kwargs)
  214. if self.track_running_stats:
  215. self.running_mean = UninitializedBuffer(**factory_kwargs)
  216. self.running_var = UninitializedBuffer(**factory_kwargs)
  217. self.num_batches_tracked = torch.tensor(
  218. 0,
  219. dtype=torch.long,
  220. **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
  221. )
  222. def reset_parameters(self) -> None:
  223. if not self.has_uninitialized_params() and self.num_features != 0:
  224. super().reset_parameters()
  225. def initialize_parameters(self, input) -> None: # type: ignore[override]
  226. if self.has_uninitialized_params():
  227. self.num_features = input.shape[1]
  228. if self.affine:
  229. assert isinstance(self.weight, UninitializedParameter)
  230. assert isinstance(self.bias, UninitializedParameter)
  231. self.weight.materialize((self.num_features,))
  232. self.bias.materialize((self.num_features,))
  233. if self.track_running_stats:
  234. self.running_mean.materialize( # type:ignore[union-attr]
  235. (self.num_features,)
  236. )
  237. self.running_var.materialize( # type:ignore[union-attr]
  238. (self.num_features,)
  239. )
  240. self.reset_parameters()
  class BatchNorm1d(_BatchNorm):
      r"""Applies Batch Normalization over a 2D or 3D input.

      Method described in the paper
      `Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

      .. math::

          y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

      The mean and standard-deviation are calculated per-dimension over
      the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
      of size `C` (where `C` is the number of features or channels of the input). By default, the
      elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
      At train time in the forward pass, the variance is calculated via the biased estimator,
      equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
      moving average of the variance is calculated via the unbiased estimator, equivalent to
      ``torch.var(input, unbiased=True)``.

      Also by default, during training this layer keeps running estimates of its
      computed mean and variance, which are then used for normalization during
      evaluation. The running estimates are kept with a default :attr:`momentum`
      of 0.1.

      If :attr:`track_running_stats` is set to ``False``, this layer then does not
      keep running estimates, and batch statistics are instead used during
      evaluation time as well.

      .. note::
          This :attr:`momentum` argument is different from one used in optimizer
          classes and the conventional notion of momentum. Mathematically, the
          update rule for running statistics here is
          :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
          where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
          new observed value.

      Because the Batch Normalization is done over the `C` dimension, computing statistics
      on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

      Args:
          num_features: number of features or channels :math:`C` of the input
          eps: a value added to the denominator for numerical stability.
              Default: 1e-5
          momentum: the value used for the running_mean and running_var
              computation. Can be set to ``None`` for cumulative moving average
              (i.e. simple average). Default: 0.1
          affine: a boolean value that when set to ``True``, this module has
              learnable affine parameters. Default: ``True``
          track_running_stats: a boolean value that when set to ``True``, this
              module tracks the running mean and variance, and when set to ``False``,
              this module does not track such statistics, and initializes statistics
              buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
              When these buffers are ``None``, this module always uses batch statistics
              in both training and eval modes. Default: ``True``

      Shape:
          - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
            :math:`C` is the number of features or channels, and :math:`L` is the sequence length
          - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

      Examples::

          >>> # With Learnable Parameters
          >>> m = nn.BatchNorm1d(100)
          >>> # Without Learnable Parameters
          >>> m = nn.BatchNorm1d(100, affine=False)
          >>> input = torch.randn(20, 100)
          >>> output = m(input)
      """

      def _check_input_dim(self, input) -> None:
          """Raise ``ValueError`` unless ``input`` is 2D or 3D."""
          if input.dim() != 2 and input.dim() != 3:
              raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
  class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
      r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.

      Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
      from the ``input.size(1)``.
      The attributes that will be lazily initialized are `weight`, `bias`,
      `running_mean` and `running_var`.

      Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
      on lazy modules and their limitations.

      Args:
          eps: a value added to the denominator for numerical stability.
              Default: 1e-5
          momentum: the value used for the running_mean and running_var
              computation. Can be set to ``None`` for cumulative moving average
              (i.e. simple average). Default: 0.1
          affine: a boolean value that when set to ``True``, this module has
              learnable affine parameters. Default: ``True``
          track_running_stats: a boolean value that when set to ``True``, this
              module tracks the running mean and variance, and when set to ``False``,
              this module does not track such statistics, and initializes statistics
              buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
              When these buffers are ``None``, this module always uses batch statistics
              in both training and eval modes. Default: ``True``
      """

      # Class the lazy-module machinery is told to turn this module into.
      cls_to_become = BatchNorm1d  # type: ignore[assignment]

      def _check_input_dim(self, input) -> None:
          """Raise ``ValueError`` unless ``input`` is 2D or 3D."""
          if input.dim() != 2 and input.dim() != 3:
              raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
  class BatchNorm2d(_BatchNorm):
      r"""Applies Batch Normalization over a 4D input.

      4D is a mini-batch of 2D inputs
      with additional channel dimension. Method described in the paper
      `Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

      .. math::

          y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

      The mean and standard-deviation are calculated per-dimension over
      the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
      of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
      to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
      variance is calculated via the biased estimator, equivalent to
      ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
      variance is calculated via the unbiased estimator, equivalent to
      ``torch.var(input, unbiased=True)``.

      Also by default, during training this layer keeps running estimates of its
      computed mean and variance, which are then used for normalization during
      evaluation. The running estimates are kept with a default :attr:`momentum`
      of 0.1.

      If :attr:`track_running_stats` is set to ``False``, this layer then does not
      keep running estimates, and batch statistics are instead used during
      evaluation time as well.

      .. note::
          This :attr:`momentum` argument is different from one used in optimizer
          classes and the conventional notion of momentum. Mathematically, the
          update rule for running statistics here is
          :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
          where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
          new observed value.

      Because the Batch Normalization is done over the `C` dimension, computing statistics
      on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

      Args:
          num_features: :math:`C` from an expected input of size
              :math:`(N, C, H, W)`
          eps: a value added to the denominator for numerical stability.
              Default: 1e-5
          momentum: the value used for the running_mean and running_var
              computation. Can be set to ``None`` for cumulative moving average
              (i.e. simple average). Default: 0.1
          affine: a boolean value that when set to ``True``, this module has
              learnable affine parameters. Default: ``True``
          track_running_stats: a boolean value that when set to ``True``, this
              module tracks the running mean and variance, and when set to ``False``,
              this module does not track such statistics, and initializes statistics
              buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
              When these buffers are ``None``, this module always uses batch statistics
              in both training and eval modes. Default: ``True``

      Shape:
          - Input: :math:`(N, C, H, W)`
          - Output: :math:`(N, C, H, W)` (same shape as input)

      Examples::

          >>> # With Learnable Parameters
          >>> m = nn.BatchNorm2d(100)
          >>> # Without Learnable Parameters
          >>> m = nn.BatchNorm2d(100, affine=False)
          >>> input = torch.randn(20, 100, 35, 45)
          >>> output = m(input)
      """

      def _check_input_dim(self, input) -> None:
          """Raise ``ValueError`` unless ``input`` is 4D."""
          if input.dim() != 4:
              raise ValueError(f"expected 4D input (got {input.dim()}D input)")
  class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
      r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

      Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
      from the ``input.size(1)``.
      The attributes that will be lazily initialized are `weight`, `bias`,
      `running_mean` and `running_var`.

      Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
      on lazy modules and their limitations.

      Args:
          eps: a value added to the denominator for numerical stability.
              Default: 1e-5
          momentum: the value used for the running_mean and running_var
              computation. Can be set to ``None`` for cumulative moving average
              (i.e. simple average). Default: 0.1
          affine: a boolean value that when set to ``True``, this module has
              learnable affine parameters. Default: ``True``
          track_running_stats: a boolean value that when set to ``True``, this
              module tracks the running mean and variance, and when set to ``False``,
              this module does not track such statistics, and initializes statistics
              buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
              When these buffers are ``None``, this module always uses batch statistics
              in both training and eval modes. Default: ``True``
      """

      # Class the lazy-module machinery is told to turn this module into.
      cls_to_become = BatchNorm2d  # type: ignore[assignment]

      def _check_input_dim(self, input) -> None:
          """Raise ``ValueError`` unless ``input`` is 4D."""
          if input.dim() != 4:
              raise ValueError(f"expected 4D input (got {input.dim()}D input)")
  class BatchNorm3d(_BatchNorm):
      r"""Applies Batch Normalization over a 5D input.

      5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
      `Batch Normalization: Accelerating Deep Network Training by Reducing
      Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

      .. math::

          y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

      The mean and standard-deviation are calculated per-dimension over
      the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
      of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
      to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
      variance is calculated via the biased estimator, equivalent to
      ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
      variance is calculated via the unbiased estimator, equivalent to
      ``torch.var(input, unbiased=True)``.

      Also by default, during training this layer keeps running estimates of its
      computed mean and variance, which are then used for normalization during
      evaluation. The running estimates are kept with a default :attr:`momentum`
      of 0.1.

      If :attr:`track_running_stats` is set to ``False``, this layer then does not
      keep running estimates, and batch statistics are instead used during
      evaluation time as well.

      .. note::
          This :attr:`momentum` argument is different from one used in optimizer
          classes and the conventional notion of momentum. Mathematically, the
          update rule for running statistics here is
          :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
          where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
          new observed value.

      Because the Batch Normalization is done over the `C` dimension, computing statistics
      on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
      or Spatio-temporal Batch Normalization.

      Args:
          num_features: :math:`C` from an expected input of size
              :math:`(N, C, D, H, W)`
          eps: a value added to the denominator for numerical stability.
              Default: 1e-5
          momentum: the value used for the running_mean and running_var
              computation. Can be set to ``None`` for cumulative moving average
              (i.e. simple average). Default: 0.1
          affine: a boolean value that when set to ``True``, this module has
              learnable affine parameters. Default: ``True``
          track_running_stats: a boolean value that when set to ``True``, this
              module tracks the running mean and variance, and when set to ``False``,
              this module does not track such statistics, and initializes statistics
              buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
              When these buffers are ``None``, this module always uses batch statistics
              in both training and eval modes. Default: ``True``

      Shape:
          - Input: :math:`(N, C, D, H, W)`
          - Output: :math:`(N, C, D, H, W)` (same shape as input)

      Examples::

          >>> # With Learnable Parameters
          >>> m = nn.BatchNorm3d(100)
          >>> # Without Learnable Parameters
          >>> m = nn.BatchNorm3d(100, affine=False)
          >>> input = torch.randn(20, 100, 35, 45, 10)
          >>> output = m(input)
      """

      def _check_input_dim(self, input) -> None:
          """Raise ``ValueError`` unless ``input`` is 5D."""
          if input.dim() != 5:
              raise ValueError(f"expected 5D input (got {input.dim()}D input)")
  class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
      r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

      Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
      from the ``input.size(1)``.
      The attributes that will be lazily initialized are `weight`, `bias`,
      `running_mean` and `running_var`.

      Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
      on lazy modules and their limitations.

      Args:
          eps: a value added to the denominator for numerical stability.
              Default: 1e-5
          momentum: the value used for the running_mean and running_var
              computation. Can be set to ``None`` for cumulative moving average
              (i.e. simple average). Default: 0.1
          affine: a boolean value that when set to ``True``, this module has
              learnable affine parameters. Default: ``True``
          track_running_stats: a boolean value that when set to ``True``, this
              module tracks the running mean and variance, and when set to ``False``,
              this module does not track such statistics, and initializes statistics
              buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
              When these buffers are ``None``, this module always uses batch statistics
              in both training and eval modes. Default: ``True``
      """

      # Class the lazy-module machinery is told to turn this module into.
      cls_to_become = BatchNorm3d  # type: ignore[assignment]

      def _check_input_dim(self, input) -> None:
          """Raise ``ValueError`` unless ``input`` is 5D."""
          if input.dim() != 5:
              raise ValueError(f"expected 5D input (got {input.dim()}D input)")
  507. class SyncBatchNorm(_BatchNorm):
  508. r"""Applies Batch Normalization over a N-Dimensional input.
  509. The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension as described in the paper
  510. `Batch Normalization: Accelerating Deep Network Training by Reducing
  511. Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .
  512. .. math::
  513. y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
  514. The mean and standard-deviation are calculated per-dimension over all
  515. mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
  516. are learnable parameter vectors of size `C` (where `C` is the input size).
  517. By default, the elements of :math:`\gamma` are set to 1
  518. and the elements of :math:`\beta` are set to 0.
  519. The standard-deviation is calculated via the biased estimator, equivalent to
  520. `torch.var(input, unbiased=False)`.
  521. Also by default, during training this layer keeps running estimates of its
  522. computed mean and variance, which are then used for normalization during
  523. evaluation. The running estimates are kept with a default :attr:`momentum`
  524. of 0.1.
  525. If :attr:`track_running_stats` is set to ``False``, this layer then does not
  526. keep running estimates, and batch statistics are instead used during
  527. evaluation time as well.
  528. .. note::
  529. This :attr:`momentum` argument is different from one used in optimizer
  530. classes and the conventional notion of momentum. Mathematically, the
  531. update rule for running statistics here is
  532. :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
  533. where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
  534. new observed value.
  535. Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
  536. statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
  537. Normalization or Spatio-temporal Batch Normalization.
  538. Currently :class:`SyncBatchNorm` only supports
  539. :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
  540. :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
  541. :attr:`BatchNorm*D` layer to :class:`SyncBatchNorm` before wrapping
  542. Network with DDP.
  543. Args:
  544. num_features: :math:`C` from an expected input of size
  545. :math:`(N, C, +)`
  546. eps: a value added to the denominator for numerical stability.
  547. Default: ``1e-5``
  548. momentum: the value used for the running_mean and running_var
  549. computation. Can be set to ``None`` for cumulative moving average
  550. (i.e. simple average). Default: 0.1
  551. affine: a boolean value that when set to ``True``, this module has
  552. learnable affine parameters. Default: ``True``
  553. track_running_stats: a boolean value that when set to ``True``, this
  554. module tracks the running mean and variance, and when set to ``False``,
  555. this module does not track such statistics, and initializes statistics
  556. buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
  557. When these buffers are ``None``, this module always uses batch statistics
  558. in both training and eval modes. Default: ``True``
  559. process_group: synchronization of stats happen within each process group
  560. individually. Default behavior is synchronization across the whole
  561. world
  562. Shape:
  563. - Input: :math:`(N, C, +)`
  564. - Output: :math:`(N, C, +)` (same shape as input)
  565. .. note::
  566. Synchronization of batchnorm statistics occurs only while training, i.e.
  567. synchronization is disabled when ``model.eval()`` is set or if
  568. ``self.training`` is otherwise ``False``.
  569. Examples::
  570. >>> # xdoctest: +SKIP
  571. >>> # With Learnable Parameters
  572. >>> m = nn.SyncBatchNorm(100)
  573. >>> # creating process group (optional)
  574. >>> # ranks is a list of int identifying rank ids.
  575. >>> ranks = list(range(8))
  576. >>> r1, r2 = ranks[:4], ranks[4:]
  577. >>> # Note: every rank calls into new_group for every
  578. >>> # process group created, even if that rank is not
  579. >>> # part of the group.
  580. >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
  581. >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
  582. >>> # Without Learnable Parameters
  583. >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
  584. >>> input = torch.randn(20, 100, 35, 45, 10)
  585. >>> output = m(input)
  586. >>> # network is nn.BatchNorm layer
  587. >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
  588. >>> # only single gpu per process is currently supported
  589. >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
  590. >>> sync_bn_network,
  591. >>> device_ids=[args.local_rank],
  592. >>> output_device=args.local_rank)
  593. """
  def __init__(
      self,
      num_features: int,
      eps: float = 1e-5,
      momentum: Optional[float] = 0.1,
      affine: bool = True,
      track_running_stats: bool = True,
      process_group: Optional[Any] = None,
      device=None,
      dtype=None,
  ) -> None:
      """Initialize the layer and remember its synchronization group.

      The normalization settings, together with ``device`` and ``dtype``,
      are handed unchanged to the parent initializer; ``process_group`` is
      only stored on the instance.
      """
      super().__init__(
          num_features,
          eps,
          momentum,
          affine,
          track_running_stats,
          device=device,
          dtype=dtype,
      )
      self.process_group = process_group
  def _check_input_dim(self, input) -> None:
      """Reject inputs that have fewer than two dimensions.

      Raises:
          ValueError: if ``input.dim()`` is below 2.
      """
      if input.dim() >= 2:
          return
      raise ValueError(f"expected at least 2D input (got {input.dim()}D input)")
  def _check_non_zero_input_channels(self, input) -> None:
      """Reject inputs whose dimension 1 has size zero.

      Raises:
          ValueError: if ``input.size(1)`` equals 0.
      """
      has_channels = input.size(1) != 0
      if not has_channels:
          raise ValueError(
              "SyncBatchNorm number of input channels should be non-zero"
          )
  def forward(self, input: Tensor) -> Tensor:
      """Normalize ``input``, synchronizing batch statistics when required.

      The input is validated first. While training with tracked statistics,
      ``num_batches_tracked`` is incremented in place. The call is dispatched
      to ``sync_batch_norm.apply`` only when the module is training,
      ``torch.distributed`` is available and initialized, and the selected
      process group spans more than one rank; every other case is handled
      by ``F.batch_norm``.

      Args:
          input: tensor to normalize.

      Returns:
          The result of ``F.batch_norm`` or of ``sync_batch_norm.apply``.

      Raises:
          ValueError: if an input check fails, or if synchronization is
              needed and ``input`` is not on a ``cuda``, ``xpu`` or
              PrivateUse1 device.
      """
      self._check_input_dim(input)
      self._check_non_zero_input_channels(input)

      momentum = self.momentum
      is_training = self.training
      tracking_stats = self.track_running_stats

      # The factor starts out as the momentum (when one is set) only so that
      # it gets updated in the ONNX graph when this node is exported to ONNX.
      exponential_average_factor = 0.0 if momentum is None else momentum

      if is_training and tracking_stats:
          assert self.num_batches_tracked is not None
          self.num_batches_tracked.add_(1)
          if momentum is None:
              # No momentum: switch to a cumulative moving average.
              exponential_average_factor = 1.0 / self.num_batches_tracked.item()

      # Mini-batch statistics are used for normalization in training mode,
      # and in eval mode when both buffers are None.
      bn_training = is_training or (
          self.running_mean is None and self.running_var is None
      )

      # The buffers are passed along only when they may be updated (training
      # with tracking enabled) or when they are read for normalization (eval
      # mode). Otherwise None is passed so they cannot be updated.
      pass_buffers = (not is_training) or tracking_stats
      running_mean = self.running_mean if pass_buffers else None
      running_var = self.running_var if pass_buffers else None

      # Statistics are never synchronized in inference mode (model.eval()).
      process_group = None
      world_size = 1
      synchronize = False
      if (
          bn_training
          and is_training
          and torch.distributed.is_available()
          and torch.distributed.is_initialized()
      ):
          # currently only GPU/PrivateUse1 input is supported
          supported_devices = (
              "cuda",
              "xpu",
              torch._C._get_privateuse1_backend_name(),
          )
          if input.device.type not in supported_devices:
              raise ValueError(
                  "SyncBatchNorm expected input tensor to be on GPU or XPU or "
                  f"{torch._C._get_privateuse1_backend_name()}"
              )

          # Use the configured group when it is truthy, else the world group.
          process_group = self.process_group or torch.distributed.group.WORLD
          world_size = torch.distributed.get_world_size(process_group)
          synchronize = world_size > 1

      if synchronize:
          assert bn_training
          return sync_batch_norm.apply(
              input,
              self.weight,
              self.bias,
              running_mean,
              running_var,
              self.eps,
              exponential_average_factor,
              process_group,
              world_size,
          )

      # fallback to framework BN when synchronization is not necessary
      return F.batch_norm(
          input,
          running_mean,
          running_var,
          self.weight,
          self.bias,
          bn_training,
          exponential_average_factor,
          self.eps,
      )
  @classmethod
  def convert_sync_batchnorm(cls, module, process_group=None):
      r"""Recursively replace :attr:`BatchNorm*D` layers with :class:`torch.nn.SyncBatchNorm`.

      A batch-norm ``module`` is replaced by a freshly built
      :class:`torch.nn.SyncBatchNorm` that reuses the original's affine
      parameters (when ``affine`` is set), running statistics buffers,
      ``training`` flag and, if present, ``qconfig``. Any other module is
      kept, and each of its children is converted and re-registered under
      the same name.

      Args:
          module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
          process_group (optional): process group to scope synchronization,
              default is the whole world

      Returns:
          The original :attr:`module` with its :attr:`BatchNorm*D` layers
          converted. If :attr:`module` itself is a :attr:`BatchNorm*D`
          layer, a new :class:`torch.nn.SyncBatchNorm` object is returned
          instead.

      Example::

          >>> # Network with nn.BatchNorm layer
          >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
          >>> module = torch.nn.Sequential(
          >>>     torch.nn.Linear(20, 100),
          >>>     torch.nn.BatchNorm1d(100),
          >>> ).cuda()
          >>> # creating process group (optional)
          >>> # ranks is a list of int identifying rank ids.
          >>> ranks = list(range(8))
          >>> r1, r2 = ranks[:4], ranks[4:]
          >>> # Note: every rank calls into new_group for every
          >>> # process group created, even if that rank is not
          >>> # part of the group.
          >>> # xdoctest: +SKIP("distributed")
          >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
          >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
          >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)

      """
      if not isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
          # Not a batch-norm layer: keep the module, only its children change.
          converted = module
      else:
          converted = torch.nn.SyncBatchNorm(
              module.num_features,
              module.eps,
              module.momentum,
              module.affine,
              module.track_running_stats,
              process_group,
          )
          if module.affine:
              # Reuse the very same parameter objects as the source layer.
              with torch.no_grad():
                  converted.weight = module.weight
                  converted.bias = module.bias
          # Carry over the statistics buffers and the train/eval state.
          converted.running_mean = module.running_mean
          converted.running_var = module.running_var
          converted.num_batches_tracked = module.num_batches_tracked
          converted.training = module.training
          if hasattr(module, "qconfig"):
              converted.qconfig = module.qconfig

      # Convert every child and register the result under its original name.
      for child_name, child_module in module.named_children():
          replacement = cls.convert_sync_batchnorm(child_module, process_group)
          converted.add_module(child_name, replacement)
      return converted