swa_utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. # mypy: allow-untyped-defs
r"""Utilities for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA)."""
  3. import itertools
  4. import math
  5. import warnings
  6. from collections.abc import Callable, Iterable
  7. from copy import deepcopy
  8. from typing import Any, cast, Literal, Union
  9. from typing_extensions import override
  10. import torch
  11. from torch import Tensor
  12. from torch.nn import Module
  13. from torch.optim.lr_scheduler import _format_param, LRScheduler
  14. from torch.utils._foreach_utils import _get_foreach_kernels_supported_devices
  15. from .optimizer import Optimizer
  16. __all__ = [
  17. "AveragedModel",
  18. "update_bn",
  19. "SWALR",
  20. "get_ema_multi_avg_fn",
  21. "get_swa_multi_avg_fn",
  22. "get_ema_avg_fn",
  23. "get_swa_avg_fn",
  24. ]
  25. from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
  26. PARAM_LIST = Union[tuple[Tensor, ...], list[Tensor]]
  27. def get_ema_multi_avg_fn(decay=0.999):
  28. """Get the function applying exponential moving average (EMA) across multiple params."""
  29. if decay < 0.0 or decay > 1.0:
  30. raise ValueError(
  31. f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
  32. )
  33. @torch.no_grad()
  34. def ema_update(
  35. ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _
  36. ) -> None:
  37. # foreach lerp only handles float and complex
  38. if torch.is_floating_point(ema_param_list[0]) or torch.is_complex(
  39. ema_param_list[0]
  40. ):
  41. torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
  42. else:
  43. for p_ema, p_model in zip(ema_param_list, current_param_list, strict=True):
  44. p_ema.copy_(p_ema * decay + p_model * (1 - decay))
  45. return ema_update
  46. def get_swa_multi_avg_fn():
  47. """Get the function applying stochastic weight average (SWA) across multiple params."""
  48. @torch.no_grad()
  49. def swa_update(
  50. averaged_param_list: PARAM_LIST,
  51. current_param_list: PARAM_LIST,
  52. num_averaged: Tensor | int,
  53. ) -> None:
  54. # foreach lerp only handles float and complex
  55. if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex(
  56. averaged_param_list[0]
  57. ):
  58. torch._foreach_lerp_(
  59. averaged_param_list,
  60. current_param_list,
  61. cast(float, 1 / (num_averaged + 1)),
  62. )
  63. else:
  64. diffs = torch._foreach_sub(current_param_list, averaged_param_list)
  65. if isinstance(num_averaged, Tensor):
  66. torch._foreach_addcdiv_(
  67. averaged_param_list,
  68. diffs,
  69. [num_averaged + 1] * len(averaged_param_list),
  70. )
  71. else:
  72. torch._foreach_add_(
  73. averaged_param_list, diffs, alpha=1.0 / (num_averaged + 1)
  74. )
  75. return swa_update
  76. def get_ema_avg_fn(decay=0.999):
  77. """Get the function applying exponential moving average (EMA) across a single param."""
  78. if decay < 0.0 or decay > 1.0:
  79. raise ValueError(
  80. f"Invalid decay value {decay} provided. Please provide a value in [0,1] range."
  81. )
  82. @torch.no_grad()
  83. def ema_update(ema_param: Tensor, current_param: Tensor, num_averaged):
  84. return decay * ema_param + (1 - decay) * current_param
  85. return ema_update
  86. def get_swa_avg_fn():
  87. """Get the function applying stochastic weight average (SWA) across a single param."""
  88. @torch.no_grad()
  89. def swa_update(
  90. averaged_param: Tensor, current_param: Tensor, num_averaged: Tensor | int
  91. ):
  92. return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
  93. return swa_update
  94. class AveragedModel(Module):
  95. r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).
  96. Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
  97. Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
  98. Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
  99. (UAI 2018).
  100. Exponential Moving Average is a variation of `Polyak averaging`_,
  101. but using exponential weights instead of equal weights across iterations.
  102. AveragedModel class creates a copy of the provided module :attr:`model`
  103. on the device :attr:`device` and allows to compute running averages of the
  104. parameters of the :attr:`model`.
  105. Args:
  106. model (torch.nn.Module): model to use with SWA/EMA
  107. device (torch.device, optional): if provided, the averaged model will be
  108. stored on the :attr:`device`
  109. avg_fn (function, optional): the averaging function used to update
  110. parameters; the function must take in the current value of the
  111. :class:`AveragedModel` parameter, the current value of :attr:`model`
  112. parameter, and the number of models already averaged; if None,
  113. an equally weighted average is used (default: None)
  114. multi_avg_fn (function, optional): the averaging function used to update
  115. parameters inplace; the function must take in the current values of the
  116. :class:`AveragedModel` parameters as a list, the current values of :attr:`model`
  117. parameters as a list, and the number of models already averaged; if None,
  118. an equally weighted average is used (default: None)
  119. use_buffers (bool): if ``True``, it will compute running averages for
  120. both the parameters and the buffers of the model. (default: ``False``)
  121. Example:
  122. >>> # xdoctest: +SKIP("undefined variables")
  123. >>> loader, optimizer, model, loss_fn = ...
  124. >>> swa_model = torch.optim.swa_utils.AveragedModel(model)
  125. >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
  126. >>> T_max=300)
  127. >>> swa_start = 160
  128. >>> swa_scheduler = SWALR(optimizer, swa_lr=0.05)
  129. >>> for i in range(300):
  130. >>> for input, target in loader:
  131. >>> optimizer.zero_grad()
  132. >>> loss_fn(model(input), target).backward()
  133. >>> optimizer.step()
  134. >>> if i > swa_start:
  135. >>> swa_model.update_parameters(model)
  136. >>> swa_scheduler.step()
  137. >>> else:
  138. >>> scheduler.step()
  139. >>>
  140. >>> # Update bn statistics for the swa_model at the end
  141. >>> torch.optim.swa_utils.update_bn(loader, swa_model)
  142. You can also use custom averaging functions with the `avg_fn` or `multi_avg_fn` parameters.
  143. If no averaging function is provided, the default is to compute
  144. equally-weighted average of the weights (SWA).
  145. Example:
  146. >>> # xdoctest: +SKIP("undefined variables")
  147. >>> # Compute exponential moving averages of the weights and buffers
  148. >>> ema_model = torch.optim.swa_utils.AveragedModel(model,
  149. >>> torch.optim.swa_utils.get_ema_multi_avg_fn(0.9), use_buffers=True)
  150. .. note::
  151. When using SWA/EMA with models containing Batch Normalization you may
  152. need to update the activation statistics for Batch Normalization.
  153. This can be done either by using the :meth:`torch.optim.swa_utils.update_bn`
  154. or by setting :attr:`use_buffers` to `True`. The first approach updates the
  155. statistics in a post-training step by passing data through the model. The
  156. second does it during the parameter update phase by averaging all buffers.
  157. Empirical evidence has shown that updating the statistics in normalization
  158. layers increases accuracy, but you may wish to empirically test which
  159. approach yields the best results in your problem.
  160. .. note::
  161. :attr:`avg_fn` and `multi_avg_fn` are not saved in the :meth:`state_dict` of the model.
  162. .. note::
  163. When :meth:`update_parameters` is called for the first time (i.e.
  164. :attr:`n_averaged` is `0`) the parameters of `model` are copied
  165. to the parameters of :class:`AveragedModel`. For every subsequent
  166. call of :meth:`update_parameters` the function `avg_fn` is used
  167. to update the parameters.
  168. .. _Averaging Weights Leads to Wider Optima and Better Generalization:
  169. https://arxiv.org/abs/1803.05407
  170. .. _There Are Many Consistent Explanations of Unlabeled Data: Why You Should
  171. Average:
  172. https://arxiv.org/abs/1806.05594
  173. .. _SWALP: Stochastic Weight Averaging in Low-Precision Training:
  174. https://arxiv.org/abs/1904.11943
  175. .. _Stochastic Weight Averaging in Parallel: Large-Batch Training That
  176. Generalizes Well:
  177. https://arxiv.org/abs/2001.02312
  178. .. _Polyak averaging:
  179. https://paperswithcode.com/method/polyak-averaging
  180. """
  181. n_averaged: Tensor
  182. def __init__(
  183. self,
  184. model: Module,
  185. device: int | torch.device | None = None,
  186. avg_fn: Callable[[Tensor, Tensor, Tensor | int], Tensor] | None = None,
  187. multi_avg_fn: Callable[[PARAM_LIST, PARAM_LIST, Tensor | int], None]
  188. | None = None,
  189. use_buffers=False,
  190. ) -> None: # noqa: D107
  191. super().__init__()
  192. if avg_fn is not None and multi_avg_fn is not None:
  193. raise AssertionError(
  194. "Only one of avg_fn and multi_avg_fn should be provided"
  195. )
  196. self.module = deepcopy(model)
  197. if device is not None:
  198. self.module = self.module.to(device)
  199. self.register_buffer(
  200. "n_averaged", torch.tensor(0, dtype=torch.long, device=device)
  201. )
  202. self.avg_fn = avg_fn
  203. self.multi_avg_fn = multi_avg_fn
  204. self.use_buffers = use_buffers
  205. def forward(self, *args, **kwargs):
  206. """Forward pass."""
  207. return self.module(*args, **kwargs)
  208. def update_parameters(self, model: Module) -> None:
  209. """Update model parameters."""
  210. self_param = (
  211. # pyrefly: ignore [bad-argument-type]
  212. itertools.chain(self.module.parameters(), self.module.buffers())
  213. if self.use_buffers
  214. else self.parameters()
  215. )
  216. model_param = (
  217. # pyrefly: ignore [bad-argument-type]
  218. itertools.chain(model.parameters(), model.buffers())
  219. if self.use_buffers
  220. else model.parameters()
  221. )
  222. self_param_detached: list[Tensor | None] = []
  223. model_param_detached: list[Tensor | None] = []
  224. copy_param = bool(self.n_averaged == 0)
  225. for p_averaged, p_model in zip(self_param, model_param, strict=False):
  226. p_model_ = p_model.detach().to(p_averaged.device)
  227. self_param_detached.append(p_averaged.detach())
  228. model_param_detached.append(p_model_)
  229. if copy_param:
  230. p_averaged.detach().copy_(p_model_)
  231. if self.n_averaged > 0:
  232. if self.multi_avg_fn is not None or self.avg_fn is None:
  233. grouped_tensors = _group_tensors_by_device_and_dtype(
  234. [self_param_detached, model_param_detached]
  235. )
  236. for (device, _), (
  237. [self_params, model_params],
  238. _,
  239. ) in grouped_tensors.items():
  240. if self.multi_avg_fn:
  241. self.multi_avg_fn(
  242. self_params, # type: ignore[arg-type]
  243. model_params, # type: ignore[arg-type]
  244. self.n_averaged.to(device),
  245. )
  246. elif (
  247. device is not None
  248. and device.type in _get_foreach_kernels_supported_devices()
  249. ):
  250. multi_avg_fn = get_swa_multi_avg_fn()
  251. multi_avg_fn(
  252. self_params, model_params, self.n_averaged.to(device)
  253. )
  254. else:
  255. avg_fn = get_swa_avg_fn()
  256. n_averaged = self.n_averaged.to(device)
  257. for p_averaged, p_model in zip( # type: ignore[assignment]
  258. self_params, model_params, strict=True
  259. ):
  260. # pyrefly: ignore [missing-attribute]
  261. p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
  262. else:
  263. for p_averaged, p_model in zip( # type: ignore[assignment]
  264. self_param_detached, model_param_detached, strict=True
  265. ):
  266. # pyrefly: ignore [missing-attribute]
  267. n_averaged = self.n_averaged.to(p_averaged.device)
  268. # pyrefly: ignore [missing-attribute]
  269. p_averaged.detach().copy_(
  270. # pyrefly: ignore [missing-attribute, bad-argument-type]
  271. self.avg_fn(p_averaged.detach(), p_model, n_averaged)
  272. )
  273. if not self.use_buffers:
  274. # If not apply running averages to the buffers,
  275. # keep the buffers in sync with the source model.
  276. for b_swa, b_model in zip(
  277. self.module.buffers(), model.buffers(), strict=True
  278. ):
  279. b_swa.detach().copy_(b_model.detach().to(b_swa.device))
  280. self.n_averaged += 1
  281. @torch.no_grad()
  282. def update_bn(
  283. loader: Iterable[Any],
  284. model: Module,
  285. device: int | torch.device | None = None,
  286. ) -> None:
  287. r"""Update BatchNorm running_mean, running_var buffers in the model.
  288. It performs one pass over data in `loader` to estimate the activation
  289. statistics for BatchNorm layers in the model.
  290. Args:
  291. loader (torch.utils.data.DataLoader): dataset loader to compute the
  292. activation statistics on. Each data batch should be either a
  293. tensor, or a list/tuple whose first element is a tensor
  294. containing data.
  295. model (torch.nn.Module): model for which we seek to update BatchNorm
  296. statistics.
  297. device (torch.device, optional): If set, data will be transferred to
  298. :attr:`device` before being passed into :attr:`model`.
  299. Example:
  300. >>> # xdoctest: +SKIP("Undefined variables")
  301. >>> loader, model = ...
  302. >>> torch.optim.swa_utils.update_bn(loader, model)
  303. .. note::
  304. The `update_bn` utility assumes that each data batch in :attr:`loader`
  305. is either a tensor or a list or tuple of tensors; in the latter case it
  306. is assumed that :meth:`model.forward()` should be called on the first
  307. element of the list or tuple corresponding to the data batch.
  308. """
  309. momenta = {}
  310. for module in model.modules():
  311. if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
  312. module.reset_running_stats()
  313. momenta[module] = module.momentum
  314. if not momenta:
  315. return
  316. was_training = model.training
  317. model.train()
  318. for module in momenta:
  319. module.momentum = None
  320. for input in loader:
  321. if isinstance(input, (list, tuple)):
  322. input = input[0]
  323. if device is not None:
  324. input = input.to(device)
  325. model(input)
  326. for bn_module in momenta:
  327. bn_module.momentum = momenta[bn_module]
  328. model.train(was_training)
  329. class SWALR(LRScheduler):
  330. r"""Anneals the learning rate in each parameter group to a fixed value.
  331. This learning rate scheduler is meant to be used with Stochastic Weight
  332. Averaging (SWA) method (see `torch.optim.swa_utils.AveragedModel`).
  333. Args:
  334. optimizer (torch.optim.Optimizer): wrapped optimizer
  335. swa_lrs (float or list): the learning rate value for all param groups
  336. together or separately for each group.
  337. annealing_epochs (int): number of epochs in the annealing phase
  338. (default: 10)
  339. annealing_strategy (str): "cos" or "linear"; specifies the annealing
  340. strategy: "cos" for cosine annealing, "linear" for linear annealing
  341. (default: "cos")
  342. last_epoch (int): the index of the last epoch (default: -1)
  343. The :class:`SWALR` scheduler can be used together with other
  344. schedulers to switch to a constant learning rate late in the training
  345. as in the example below.
  346. Example:
  347. >>> # xdoctest: +SKIP("Undefined variables")
  348. >>> loader, optimizer, model = ...
  349. >>> lr_lambda = lambda epoch: 0.9
  350. >>> scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer,
  351. >>> lr_lambda=lr_lambda)
  352. >>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
  353. >>> anneal_strategy="linear", anneal_epochs=20, swa_lr=0.05)
  354. >>> swa_start = 160
  355. >>> for i in range(300):
  356. >>> for input, target in loader:
  357. >>> optimizer.zero_grad()
  358. >>> loss_fn(model(input), target).backward()
  359. >>> optimizer.step()
  360. >>> if i > swa_start:
  361. >>> swa_scheduler.step()
  362. >>> else:
  363. >>> scheduler.step()
  364. .. _Averaging Weights Leads to Wider Optima and Better Generalization:
  365. https://arxiv.org/abs/1803.05407
  366. """
  367. def __init__(
  368. self,
  369. optimizer: Optimizer,
  370. swa_lr: float,
  371. anneal_epochs=10,
  372. anneal_strategy: Literal["cos", "linear"] = "cos",
  373. last_epoch=-1,
  374. ) -> None: # noqa: D107
  375. swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
  376. for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True):
  377. group["swa_lr"] = swa_lr
  378. if anneal_strategy not in ["cos", "linear"]:
  379. raise ValueError(
  380. "anneal_strategy must by one of 'cos' or 'linear', "
  381. f"instead got {anneal_strategy}"
  382. )
  383. self._set_anneal_func(anneal_strategy)
  384. if not isinstance(anneal_epochs, int) or anneal_epochs < 0:
  385. raise ValueError(
  386. f"anneal_epochs must be equal or greater than 0, got {anneal_epochs}"
  387. )
  388. self.anneal_epochs = anneal_epochs
  389. super().__init__(optimizer, last_epoch)
  390. @staticmethod
  391. def _linear_anneal(t):
  392. return t
  393. @staticmethod
  394. def _cosine_anneal(t):
  395. return (1 - math.cos(math.pi * t)) / 2
  396. @staticmethod
  397. def _get_initial_lr(lr, swa_lr, alpha):
  398. if alpha == 1:
  399. return swa_lr
  400. return (lr - alpha * swa_lr) / (1 - alpha)
  401. @override
  402. def get_lr(self):
  403. r"""Compute the next learning rate for each of the optimizer's
  404. :attr:`~torch.optim.Optimizer.param_groups`.
  405. Uses :attr:`anneal_func` to interpolate between each group's
  406. ``group["lr"]`` and ``group["swa_lr"]`` over :attr:`anneal_epochs`
  407. epochs. Once :attr:`anneal_epochs` is reached, keeps the learning rate
  408. fixed at ``group["swa_lr"]``.
  409. Returns:
  410. list[float | Tensor]: A :class:`list` of learning rates for each of
  411. the optimizer's :attr:`~torch.optim.Optimizer.param_groups` with the
  412. same types as their current ``group["lr"]``\s.
  413. .. note::
  414. If you're trying to inspect the most recent learning rate, use
  415. :meth:`get_last_lr()` instead.
  416. .. note::
  417. The returned :class:`~torch.Tensor`\s are copies, and never alias
  418. the optimizer's ``group["lr"]``\s.
  419. """
  420. # `_get_lr_called_within_step` is only available `_enable_get_lr_call`,
  421. # so we ignore the type error here. See `LRScheduler.step()` for more details.
  422. if not self._get_lr_called_within_step:
  423. warnings.warn(
  424. "To get the last learning rate computed by the scheduler, "
  425. "please use `get_last_lr()`.",
  426. UserWarning,
  427. stacklevel=2,
  428. )
  429. # Set in `LRScheduler._initial_step()`
  430. step = self._step_count - 1
  431. if self.anneal_epochs == 0:
  432. step = max(1, step)
  433. # pyrefly: ignore [no-matching-overload]
  434. prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
  435. prev_alpha = self.anneal_func(prev_t)
  436. prev_lrs = [
  437. self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
  438. for group in self.optimizer.param_groups
  439. ]
  440. # pyrefly: ignore [no-matching-overload]
  441. t = max(0, min(1, step / max(1, self.anneal_epochs)))
  442. alpha = self.anneal_func(t)
  443. return [
  444. group["swa_lr"] * alpha + lr * (1 - alpha)
  445. for group, lr in zip(self.optimizer.param_groups, prev_lrs, strict=True)
  446. ]
  447. def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]) -> None:
  448. self._anneal_strategy = anneal_strategy
  449. if anneal_strategy == "cos":
  450. self.anneal_func = self._cosine_anneal
  451. else:
  452. self.anneal_func = self._linear_anneal
  453. @override
  454. def state_dict(self) -> dict[str, Any]:
  455. """Return the state of the scheduler as a :class:`dict`.
  456. It contains an entry for every variable in self.__dict__ which
  457. is not the optimizer or anneal_func.
  458. """
  459. return {
  460. key: value
  461. for key, value in self.__dict__.items()
  462. if key not in ("optimizer", "anneal_func")
  463. }
  464. @override
  465. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  466. """Load the scheduler's state.
  467. Args:
  468. state_dict (dict): scheduler state. Should be an object returned
  469. from a call to :meth:`state_dict`.
  470. """
  471. self.__dict__.update(state_dict)
  472. self._set_anneal_func(self._anneal_strategy)