swa_utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. # mypy: allow-untyped-defs
  2. r"""Implementation for Stochastic Weight Averaging implementation."""
import itertools
import math
import warnings
from collections.abc import Iterable
from copy import deepcopy
from typing import Any, Callable, cast, Literal, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim.lr_scheduler import _format_param, LRScheduler
from torch.utils._foreach_utils import (
    _get_foreach_kernels_supported_devices,
    _group_tensors_by_device_and_dtype,
)

from .optimizer import Optimizer
  15. __all__ = [
  16. "AveragedModel",
  17. "update_bn",
  18. "SWALR",
  19. "get_ema_multi_avg_fn",
  20. "get_swa_multi_avg_fn",
  21. "get_ema_avg_fn",
  22. "get_swa_avg_fn",
  23. ]
  24. from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
  25. PARAM_LIST = Union[tuple[Tensor, ...], list[Tensor]]
  26. def get_ema_multi_avg_fn(decay=0.999):
  27. """Get the function applying exponential moving average (EMA) across multiple params."""
  28. if decay < 0.0 or decay > 1.0:
  29. raise ValueError(
  30. f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
  31. )
  32. @torch.no_grad()
  33. def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _):
  34. # foreach lerp only handles float and complex
  35. if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(
  36. ema_param_list[0]
  37. ):
  38. torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
  39. else:
  40. for p_ema, p_model in zip(ema_param_list, current_param_list):
  41. p_ema.copy_(p_ema * decay + p_model * (1 - decay))
  42. return ema_update
  43. def get_swa_multi_avg_fn():
  44. """Get the function applying stochastic weight average (SWA) across multiple params."""
  45. @torch.no_grad()
  46. def swa_update(
  47. averaged_param_list: PARAM_LIST,
  48. current_param_list: PARAM_LIST,
  49. num_averaged: Union[Tensor, int],
  50. ):
  51. # foreach lerp only handles float and complex
  52. if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(
  53. averaged_param_list[0]
  54. ):
  55. torch._foreach_lerp_(
  56. averaged_param_list,
  57. current_param_list,
  58. cast(float, 1 / (num_averaged + 1)),
  59. )
  60. else:
  61. diffs = torch._foreach_sub(current_param_list, averaged_param_list)
  62. if isinstance(num_averaged, Tensor):
  63. torch._foreach_addcdiv_(
  64. averaged_param_list,
  65. diffs,
  66. [num_averaged + 1] * len(averaged_param_list),
  67. )
  68. else:
  69. torch._foreach_add_(
  70. averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1)
  71. )
  72. return swa_update
  73. def get_ema_avg_fn(decay=0.999):
  74. """Get the function applying exponential moving average (EMA) across a single param."""
  75. if decay < 0.0 or decay > 1.0:
  76. raise ValueError(
  77. f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
  78. )
  79. @torch.no_grad()
  80. def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
  81. return decay * ema_param + (1 - decay) * current_param
  82. return ema_update
  83. def get_swa_avg_fn():
  84. """Get the function applying stochastic weight average (SWA) across a single param."""
  85. @torch.no_grad()
  86. def swa_update(
  87. averaged_param: Tensor, current_param: Tensor, num_averaged: Union[Tensor, int]
  88. ):
  89. return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
  90. return swa_update
class AveragedModel(Module):
    r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).

    Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
    Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
    Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
    (UAI 2018).

    Exponential Moving Average is a variation of `Polyak averaging`_,
    but using exponential weights instead of equal weights across iterations.

    The AveragedModel class creates a (deep) copy of the provided module
    :attr:`model`, optionally on the device :attr:`device`, and allows computing
    running averages of the parameters of the :attr:`model`.

    Args:
        model (torch.nn.Module): model to use with SWA/EMA
        device (torch.device, optional): if provided, the averaged model will be
            stored on the :attr:`device`
        avg_fn (function, optional): the averaging function used to update
            parameters; the function must take in the current value of the
            :class:`AveragedModel` parameter, the current value of :attr:`model`
            parameter, and the number of models already averaged, and return
            the new averaged value; if None, an equally weighted average is
            used (default: None)
        multi_avg_fn (function, optional): the averaging function used to update
            parameters inplace; the function must take in the current values of the
            :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
            parameters as a list, and the number of models already averaged; if None,
            an equally weighted average is used (default: None)
        use_buffers (bool): if ``True``, it will compute running averages for
            both the parameters and the buffers of the model. (default: ``False``)

    At most one of :attr:`avg_fn` and :attr:`multi_avg_fn` may be given.

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> loader, optimizer, model, loss_fn = ...
        >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
        >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
        >>>                                     T_max=300)
        >>> swa_start = 160
        >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)
        >>> for i in range(300):
        >>>      for input, target in loader:
        >>>          optimizer.zero_grad()
        >>>          loss_fn(model(input), target).backward()
        >>>          optimizer.step()
        >>>      if i > swa_start:
        >>>          swa_model.update_parameters(model)
        >>>          swa_scheduler.step()
        >>>      else:
        >>>          scheduler.step()
        >>>
        >>> # Update bn statistics for the swa_model at the end
        >>> torch.optim.swa_utils.update_bn(loader, swa_model)

    You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
    If no averaging function is provided, the default is to compute
    equally-weighted average of the weights (SWA).

    Example:
        >>> # xdoctest: +SKIP("undefined variables")
        >>> # Compute exponential moving averages of the weights and buffers
        >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
        >>>             multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.9),
        >>>             use_buffers=True)

    .. note::
        When using SWA/EMA with models containing Batch Normalization you may
        need to update the activation statistics for Batch Normalization.
        This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
        or by setting :attr:`use_buffers` to `True`. The first approach updates the
        statistics in a post-training step by passing data through the model. The
        second does it during the parameter update phase by averaging all buffers.
        Empirical evidence has shown that updating the statistics in normalization
        layers increases accuracy, but you may wish to empirically test which
        approach yields the best results in your problem.

    .. note::
        :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.

    .. note::
        When :meth:`update_parameters` is called for the first time (i.e.
        :attr:`n_averaged` is `0`) the parameters of `model` are copied
        to the parameters of :class:`AveragedModel`. For every subsequent
        call of :meth:`update_parameters` the function `avg_fn` or
        `multi_avg_fn` (or, when neither is given, the default equal-weight
        SWA average) is used to update the parameters.

    .. _Averaging Weights Leads to Wider Optima and Better Generalization:
        https://arxiv.org/abs/1803.05407
    .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should
        Average:
        https://arxiv.org/abs/1806.05594
    .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:
        https://arxiv.org/abs/1904.11943
    .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
        Generalizes Well:
        https://arxiv.org/abs/2001.02312
    .. _Polyak averaging:
        https://paperswithcode.com/method/polyak-averaging
    """

    # Number of models folded into the average so far. Registered as a long
    # buffer in __init__, so it lives in the state_dict and follows .to().
    n_averaged: Tensor

    def __init__(
        self,
        model: Module,
        device: Optional[Union[int, torch.device]] = None,
        avg_fn: Optional[Callable[[Tensor, Tensor, Union[Tensor, int]], Tensor]] = None,
        multi_avg_fn: Optional[
            Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None]
        ] = None,
        use_buffers=False,
    ):  # noqa: D107
        super().__init__()
        assert avg_fn is None or multi_avg_fn is None, (
            "Only one of avg_fn and multi_avg_fn should be provided"
        )
        # Deep copy so the averaged weights are independent of `model`.
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer(
            "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
        )
        self.avg_fn = avg_fn
        self.multi_avg_fn = multi_avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        """Forward pass."""
        # Delegates to the averaged copy of the model.
        return self.module(*args, **kwargs)

    def update_parameters(self, model: Module):
        """Update model parameters.

        Folds the current parameters of ``model`` (and its buffers when
        ``use_buffers`` is ``True``) into the running average held by
        ``self.module``, then increments :attr:`n_averaged`. When
        ``use_buffers`` is ``False`` the buffers of ``model`` are copied
        verbatim instead of averaged.

        Args:
            model (torch.nn.Module): module whose values are averaged in. Its
                parameters (and buffers) are paired with those of the averaged
                copy positionally via ``zip``, so it is assumed to have the same
                structure as the module passed to ``__init__``. NOTE(review):
                ``zip`` silently ignores extra or missing entries.
        """
        # self.parameters() yields the same tensors as self.module.parameters():
        # `module` is the only submodule and n_averaged is a buffer.
        self_param = (
            itertools.chain(self.module.parameters(), self.module.buffers())
            if self.use_buffers
            else self.parameters()
        )
        model_param = (
            itertools.chain(model.parameters(), model.buffers())
            if self.use_buffers
            else model.parameters()
        )
        self_param_detached: list[Optional[Tensor]] = []
        model_param_detached: list[Optional[Tensor]] = []
        # First call: the average is just a copy of the incoming model.
        copy_param = bool(self.n_averaged == 0)
        for p_averaged, p_model in zip(self_param, model_param):
            # Bring the source tensor onto the averaged tensor's device.
            p_model_ = p_model.detach().to(p_averaged.device)
            self_param_detached.append(p_averaged.detach())
            model_param_detached.append(p_model_)
            if copy_param:
                p_averaged.detach().copy_(p_model_)

        if self.n_averaged > 0:
            if self.multi_avg_fn is not None or self.avg_fn is None:
                # Group into (device, dtype)-homogeneous lists so list-based
                # (foreach) update functions can process each group in one call.
                grouped_tensors = _group_tensors_by_device_and_dtype(
                    [self_param_detached, model_param_detached]
                )
                for (device, _), (
                    [self_params, model_params],
                    _,
                ) in grouped_tensors.items():
                    if self.multi_avg_fn:
                        self.multi_avg_fn(
                            self_params,  # type: ignore[arg-type]
                            model_params,  # type: ignore[arg-type]
                            self.n_averaged.to(device),
                        )
                    elif (
                        device is not None
                        and device.type in _get_foreach_kernels_supported_devices()
                    ):
                        # Default SWA on devices with foreach kernels.
                        multi_avg_fn = get_swa_multi_avg_fn()
                        multi_avg_fn(
                            self_params, model_params, self.n_averaged.to(device)
                        )
                    else:
                        # Default SWA, one tensor at a time, elsewhere.
                        avg_fn = get_swa_avg_fn()
                        n_averaged = self.n_averaged.to(device)
                        for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
                            p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
            else:
                # User-supplied per-tensor avg_fn: it returns the new value,
                # which is written back in place.
                for p_averaged, p_model in zip(  # type: ignore[assignment]
                    self_param_detached, model_param_detached
                ):
                    n_averaged = self.n_averaged.to(p_averaged.device)
                    p_averaged.detach().copy_(
                        self.avg_fn(p_averaged.detach(), p_model, n_averaged)
                    )

        if not self.use_buffers:
            # If not apply running averages to the buffers,
            # keep the buffers in sync with the source model.
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.detach().copy_(b_model.detach().to(b_swa.device))
        self.n_averaged += 1
  268. @torch.no_grad()
  269. def update_bn(
  270. loader: Iterable[Any],
  271. model: Module,
  272. device: Optional[Union[int, torch.device]] = None,
  273. ):
  274. r"""Update BatchNorm running_mean, running_var buffers in the model.
  275. It performs one pass over data in `loader` to estimate the activation
  276. statistics for BatchNorm layers in the model.
  277. Args:
  278. loader (torch.utils.data.DataLoader): dataset loader to compute the
  279. activation statistics on. Each data batch should be either a
  280. tensor, or a list/tuple whose first element is a tensor
  281. containing data.
  282. model (torch.nn.Module): model for which we seek to update BatchNorm
  283. statistics.
  284. device (torch.device, optional): If set, data will be transferred to
  285. :attr:`device` before being passed into :attr:`model`.
  286. Example:
  287. >>> # xdoctest: +SKIP("Undefined variables")
  288. >>> loader, model = ...
  289. >>> torch.optim.swa_utils.update_bn(loader, model)
  290. .. note::
  291. The `update_bn` utility assumes that each data batch in :attr:`loader`
  292. is either a tensor or a list or tuple of tensors; in the latter case it
  293. is assumed that :meth:`model.forward()` should be called on the first
  294. element of the list or tuple corresponding to the data batch.
  295. """
  296. momenta = {}
  297. for module in model.modules():
  298. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  299. module.reset_running_stats()
  300. momenta[module] = module.momentum
  301. if not momenta:
  302. return
  303. was_training = model.training
  304. model.train()
  305. for module in momenta.keys():
  306. module.momentum = None
  307. for input in loader:
  308. if isinstance(input, (list, tuple)):
  309. input = input[0]
  310. if device is not None:
  311. input = input.to(device)
  312. model(input)
  313. for bn_module in momenta.keys():
  314. bn_module.momentum = momenta[bn_module]
  315. model.train(was_training)
  316. class SWALR(LRScheduler):
  317. r"""Anneals the learning rate in each parameter group to a fixed value.
  318. This learning rate scheduler is meant to be used with Stochastic Weight
  319. Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).
  320. Args:
  321. optimizer (torch.optim.Optimizer): wrapped optimizer
  322. swa_lrs (float or list): the learning rate value for all param groups
  323. together or separately for each group.
  324. annealing_epochs (int): number of epochs in the annealing phase
  325. (default: 10)
  326. annealing_strategy (str): "cos" or "linear"; specifies the annealing
  327. strategy: "cos" for cosine annealing, "linear" for linear annealing
  328. (default: "cos")
  329. last_epoch (int): the index of the last epoch (default: -1)
  330. The :class:`SWALR` scheduler can be used together with other
  331. schedulers to switch to a constant learning rate late in the training
  332. as in the example below.
  333. Example:
  334. >>> # xdoctest: +SKIP("Undefined variables")
  335. >>> loader, optimizer, model = ...
  336. >>> lr_lambda = lambda epoch: 0.9
  337. >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
  338. >>> lr_lambda=lr_lambda)
  339. >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
  340. >>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
  341. >>> swa_start = 160
  342. >>> for i in range(300):
  343. >>> for input, target in loader:
  344. >>> optimizer.zero_grad()
  345. >>> loss_fn(model(input), target).backward()
  346. >>> optimizer.step()
  347. >>> if i > swa_start:
  348. >>> swa_scheduler.step()
  349. >>> else:
  350. >>> scheduler.step()
  351. .. _Averaging Weights Leads to Wider Optima and Better Generalization:
  352. https://arxiv.org/abs/1803.05407
  353. """
  354. def __init__(
  355. self,
  356. optimizer: Optimizer,
  357. swa_lr: float,
  358. anneal_epochs=10,
  359. anneal_strategy: Literal["cos", "linear"] = "cos",
  360. last_epoch=-1,
  361. ): # noqa: D107
  362. swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
  363. for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
  364. group["swa_lr"] = swa_lr
  365. if anneal_strategy not in ["cos", "linear"]:
  366. raise ValueError(
  367. "anneal_strategy must by one of 'cos' or 'linear', "
  368. f"instead got {anneal_strategy}"
  369. )
  370. elif anneal_strategy == "cos":
  371. self.anneal_func = self._cosine_anneal
  372. elif anneal_strategy == "linear":
  373. self.anneal_func = self._linear_anneal
  374. if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
  375. raise ValueError(
  376. f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}"
  377. )
  378. self.anneal_epochs = anneal_epochs
  379. super().__init__(optimizer, last_epoch)
  380. @staticmethod
  381. def _linear_anneal(t):
  382. return t
  383. @staticmethod
  384. def _cosine_anneal(t):
  385. return (1 - math.cos(math.pi * t)) / 2
  386. @staticmethod
  387. def _get_initial_lr(lr, swa_lr, alpha):
  388. if alpha == 1:
  389. return swa_lr
  390. return (lr - alpha * swa_lr) / (1 - alpha)
  391. def get_lr(self):
  392. """Get learning rate."""
  393. # `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
  394. # so we ignore the type error here. See `LRScheduler.step()` for more details.
  395. if not self._get_lr_called_within_step:
  396. warnings.warn(
  397. "To get the last learning rate computed by the scheduler, "
  398. "please use `get_last_lr()`.",
  399. UserWarning,
  400. )
  401. # Set in `LRScheduler._initial_step()`
  402. step = self._step_count - 1
  403. if self.anneal_epochs == 0:
  404. step = max(1, step)
  405. prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
  406. prev_alpha = self.anneal_func(prev_t)
  407. prev_lrs = [
  408. self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
  409. for group in self.optimizer.param_groups
  410. ]
  411. t = max(0, min(1, step / max(1, self.anneal_epochs)))
  412. alpha = self.anneal_func(t)
  413. return [
  414. group["swa_lr"] * alpha + lr * (1 - alpha)
  415. for group, lr in zip(self.optimizer.param_groups, prev_lrs)
  416. ]