# lr_scheduler.py (83 KB)
  1. # mypy: allow-untyped-defs
  2. r"""Learning Rate Scheduler."""
  3. from __future__ import annotations
  4. import math
  5. import types
  6. import warnings
  7. from bisect import bisect_right
  8. from collections import Counter
  9. from functools import partial, wraps
  10. from typing import (
  11. Any,
  12. Callable,
  13. cast,
  14. Literal,
  15. Optional,
  16. SupportsFloat,
  17. TYPE_CHECKING,
  18. TypedDict,
  19. Union,
  20. )
  21. from typing_extensions import override, Self
  22. from weakref import ref
  23. from torch import inf, Tensor
  24. from .optimizer import _to_scalar, Optimizer
  25. if TYPE_CHECKING:
  26. from collections.abc import Iterable, Sequence
# Names exported when this module is star-imported.
__all__ = [
    "LambdaLR",
    "MultiplicativeLR",
    "StepLR",
    "MultiStepLR",
    "ConstantLR",
    "LinearLR",
    "ExponentialLR",
    "SequentialLR",
    "CosineAnnealingLR",
    "ChainedScheduler",
    "ReduceLROnPlateau",
    "CyclicLR",
    "CosineAnnealingWarmRestarts",
    "OneCycleLR",
    "PolynomialLR",
    "LRScheduler",
]
  45. EPOCH_DEPRECATION_WARNING = (
  46. "The epoch parameter in `scheduler.step()` was not necessary and is being "
  47. "deprecated where possible. Please use `scheduler.step()` to step the "
  48. "scheduler. During the deprecation, if epoch is different from None, the "
  49. "closed form is used instead of the new chainable form, where available. "
  50. "Please open an issue if you are unable to replicate your use case: "
  51. "https://github.com/pytorch/pytorch/issues/new/choose."
  52. )
  53. def _format_param(name: str, optimizer: Optimizer, param):
  54. """Return correctly formatted lr/momentum for each param group."""
  55. def _copy(_param):
  56. return _param.clone() if isinstance(_param, Tensor) else _param
  57. if isinstance(param, (list, tuple)):
  58. if len(param) != len(optimizer.param_groups):
  59. raise ValueError(
  60. f"{name} must have the same length as optimizer.param_groups. "
  61. f"{name} has {len(param)} values, param_groups has {len(optimizer.param_groups)}."
  62. )
  63. else:
  64. param = [param] * len(optimizer.param_groups)
  65. return list(map(_copy, param))
  66. class LRScheduler:
  67. r"""Adjusts the learning rate during optimization."""
  68. _get_lr_called_within_step: bool = False
  69. _is_initial: bool = False
  70. def __init__(
  71. self,
  72. optimizer: Optimizer,
  73. last_epoch: int = -1,
  74. ) -> None: # noqa: D107
  75. # Attach optimizer
  76. if not isinstance(optimizer, Optimizer):
  77. raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
  78. self.optimizer = optimizer
  79. # Initialize epoch and base learning rates
  80. if last_epoch == -1:
  81. for group in optimizer.param_groups:
  82. initial_lr = group["lr"]
  83. if isinstance(initial_lr, Tensor):
  84. initial_lr = initial_lr.clone()
  85. group.setdefault("initial_lr", initial_lr)
  86. else:
  87. for i, group in enumerate(optimizer.param_groups):
  88. if "initial_lr" not in group:
  89. raise KeyError(
  90. "param 'initial_lr' is not specified "
  91. f"in param_groups[{i}] when resuming an optimizer"
  92. )
  93. self.base_lrs: list[float] = [
  94. group["initial_lr"] for group in optimizer.param_groups
  95. ]
  96. self.last_epoch = last_epoch
  97. # Following https://github.com/pytorch/pytorch/issues/20124
  98. # We would like to ensure that `lr_scheduler.step()` is called after
  99. # `optimizer.step()`
  100. def patch_track_step_called(opt: Optimizer):
  101. if hasattr(opt.step, "_wrapped_by_lr_sched"):
  102. # we've already patched
  103. return opt.step
  104. def wrap_step(step_fn):
  105. opt_ref = ref(self.optimizer)
  106. func = step_fn.__func__
  107. @wraps(func)
  108. def wrapper(*args, **kwargs):
  109. opt = opt_ref()
  110. opt._opt_called = True # type: ignore[union-attr]
  111. return func.__get__(opt, opt.__class__)(*args, **kwargs)
  112. wrapper._wrapped_by_lr_sched = True # type: ignore[attr-defined]
  113. return wrapper
  114. opt.step = wrap_step(opt.step) # type: ignore[method-assign]
  115. patch_track_step_called(self.optimizer)
  116. self._initial_step()
  117. def _initial_step(self) -> None:
  118. """Initialize step counts and perform a step."""
  119. self._step_count = 0
  120. with _initial_mode(self):
  121. self.step()
  122. def state_dict(self) -> dict[str, Any]:
  123. """Return the state of the scheduler as a :class:`dict`.
  124. It contains an entry for every variable in self.__dict__ which
  125. is not the optimizer.
  126. """
  127. return {
  128. key: value for key, value in self.__dict__.items() if key != "optimizer"
  129. }
  130. def load_state_dict(self, state_dict: dict[str, Any]):
  131. """Load the scheduler's state.
  132. Args:
  133. state_dict (dict): scheduler state. Should be an object returned
  134. from a call to :meth:`state_dict`.
  135. """
  136. self.__dict__.update(state_dict)
  137. def get_last_lr(self) -> list[float]:
  138. """Return last computed learning rate by current scheduler."""
  139. return self._last_lr
  140. def get_lr(self) -> list[float]:
  141. """Compute learning rate using chainable form of the scheduler."""
  142. raise NotImplementedError
  143. def step(self, epoch: Optional[int] = None) -> None:
  144. """Perform a step."""
  145. # Raise a warning if old pattern is detected
  146. # https://github.com/pytorch/pytorch/issues/20124
  147. if self._step_count == 1:
  148. if not hasattr(self.optimizer.step, "_wrapped_by_lr_sched"):
  149. warnings.warn(
  150. "Seems like `optimizer.step()` has been overridden after learning rate scheduler "
  151. "initialization. Please, make sure to call `optimizer.step()` before "
  152. "`lr_scheduler.step()`. See more details at "
  153. "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
  154. UserWarning,
  155. )
  156. # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
  157. elif not getattr(self.optimizer, "_opt_called", False):
  158. warnings.warn(
  159. "Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
  160. "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
  161. "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this "
  162. "will result in PyTorch skipping the first value of the learning rate schedule. "
  163. "See more details at "
  164. "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate",
  165. UserWarning,
  166. )
  167. self._step_count += 1
  168. if epoch is not None:
  169. warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
  170. self._update_lr(epoch)
  171. def _update_lr(self, epoch: Optional[int] = None):
  172. with _enable_get_lr_call(self):
  173. if epoch is None:
  174. self.last_epoch += 1
  175. values = self.get_lr()
  176. else:
  177. self.last_epoch = epoch
  178. if hasattr(self, "_get_closed_form_lr"):
  179. values = cast(list[float], self._get_closed_form_lr())
  180. else:
  181. values = self.get_lr()
  182. for param_group, lr in zip(self.optimizer.param_groups, values):
  183. if isinstance(param_group["lr"], Tensor):
  184. param_group["lr"].fill_(_to_scalar(lr))
  185. else:
  186. param_group["lr"] = lr
  187. self._last_lr: list[float] = [
  188. group["lr"] for group in self.optimizer.param_groups
  189. ]
  190. def _warn_get_lr_called_within_step(lr_scheduler: LRScheduler) -> None:
  191. if not lr_scheduler._get_lr_called_within_step:
  192. warnings.warn(
  193. "To get the last learning rate computed by the scheduler, "
  194. "please use `get_last_lr()`.",
  195. UserWarning,
  196. stacklevel=2,
  197. )
  198. # Including _LRScheduler for backwards compatibility
  199. # Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler).
  200. class _LRScheduler(LRScheduler):
  201. pass
  202. class _enable_get_lr_call:
  203. def __init__(self, o: LRScheduler) -> None:
  204. self.o = o
  205. def __enter__(self) -> Self:
  206. self.o._get_lr_called_within_step = True
  207. return self
  208. def __exit__(self, type, value, traceback) -> None:
  209. self.o._get_lr_called_within_step = False
  210. class _initial_mode:
  211. def __init__(self, o: LRScheduler):
  212. self.o = o
  213. def __enter__(self):
  214. self.o._is_initial = True
  215. def __exit__(self, type, value, traceback):
  216. self.o._is_initial = False
  217. class LambdaLR(LRScheduler):
  218. """Sets the initial learning rate.
  219. The learning rate of each parameter group is set to the initial lr
  220. times a given function. When last_epoch=-1, sets initial lr as lr.
  221. Args:
  222. optimizer (Optimizer): Wrapped optimizer.
  223. lr_lambda (function or list): A function which computes a multiplicative
  224. factor given an integer parameter epoch, or a list of such
  225. functions, one for each group in optimizer.param_groups.
  226. last_epoch (int): The index of last epoch. Default: -1.
  227. Example:
  228. >>> # xdoctest: +SKIP
  229. >>> # Assuming optimizer has two groups.
  230. >>> num_epochs = 100
  231. >>> lambda1 = lambda epoch: epoch // 30
  232. >>> lambda2 = lambda epoch: 0.95**epoch
  233. >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
  234. >>> for epoch in range(num_epochs):
  235. >>> train(...)
  236. >>> validate(...)
  237. >>> scheduler.step()
  238. >>>
  239. >>> # Alternatively, you can use a single lambda function for all groups.
  240. >>> scheduler = LambdaLR(opt, lr_lambda=lambda epoch: epoch // 30)
  241. >>> for epoch in range(num_epochs):
  242. >>> train(...)
  243. >>> validate(...)
  244. >>> scheduler.step()
  245. .. image:: ../scripts/lr_scheduler_images/LambdaLR.png
  246. """
  247. def __init__(
  248. self,
  249. optimizer: Optimizer,
  250. lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]],
  251. last_epoch: int = -1,
  252. ) -> None: # noqa: D107
  253. self.optimizer = optimizer
  254. self.lr_lambdas: list[Callable[[int], float]]
  255. if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
  256. self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
  257. else:
  258. if len(lr_lambda) != len(optimizer.param_groups):
  259. raise ValueError(
  260. f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
  261. )
  262. self.lr_lambdas = list(lr_lambda)
  263. super().__init__(optimizer, last_epoch)
  264. @override
  265. def state_dict(self) -> dict[str, Any]:
  266. """Return the state of the scheduler as a :class:`dict`.
  267. It contains an entry for every variable in self.__dict__ which
  268. is not the optimizer.
  269. The learning rate lambda functions will only be saved if they are callable objects
  270. and not if they are functions or lambdas.
  271. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
  272. """
  273. state_dict = {
  274. key: value
  275. for key, value in self.__dict__.items()
  276. if key not in ("optimizer", "lr_lambdas")
  277. }
  278. state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)
  279. for idx, fn in enumerate(self.lr_lambdas):
  280. if not isinstance(fn, types.FunctionType):
  281. state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
  282. return state_dict
  283. @override
  284. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  285. """Load the scheduler's state.
  286. When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
  287. Args:
  288. state_dict (dict): scheduler state. Should be an object returned
  289. from a call to :meth:`state_dict`.
  290. """
  291. lr_lambdas = state_dict.pop("lr_lambdas")
  292. self.__dict__.update(state_dict)
  293. # Restore state_dict keys in order to prevent side effects
  294. # https://github.com/pytorch/pytorch/issues/32756
  295. state_dict["lr_lambdas"] = lr_lambdas
  296. for idx, fn in enumerate(lr_lambdas):
  297. if fn is not None:
  298. self.lr_lambdas[idx].__dict__.update(fn)
  299. @override
  300. def get_lr(self) -> list[float]:
  301. """Compute learning rate."""
  302. _warn_get_lr_called_within_step(self)
  303. return [
  304. base_lr * lmbda(self.last_epoch)
  305. for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)
  306. ]
  307. class MultiplicativeLR(LRScheduler):
  308. """Multiply the learning rate of each parameter group by the factor given in the specified function.
  309. When last_epoch=-1, set initial lr as lr.
  310. Args:
  311. optimizer (Optimizer): Wrapped optimizer.
  312. lr_lambda (function or list): A function which computes a multiplicative
  313. factor given an integer parameter epoch, or a list of such
  314. functions, one for each group in optimizer.param_groups.
  315. last_epoch (int): The index of last epoch. Default: -1.
  316. Example:
  317. >>> # xdoctest: +SKIP
  318. >>> lmbda = lambda epoch: 0.95
  319. >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
  320. >>> for epoch in range(100):
  321. >>> train(...)
  322. >>> validate(...)
  323. >>> scheduler.step()
  324. .. image:: ../scripts/lr_scheduler_images/MultiplicativeLR.png
  325. """
  326. def __init__(
  327. self,
  328. optimizer: Optimizer,
  329. lr_lambda: Union[Callable[[int], float], list[Callable[[int], float]]],
  330. last_epoch: int = -1,
  331. ) -> None: # noqa: D107
  332. self.optimizer = optimizer
  333. self.lr_lambdas: list[Callable[[int], float]]
  334. if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
  335. self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
  336. else:
  337. if len(lr_lambda) != len(optimizer.param_groups):
  338. raise ValueError(
  339. f"Expected {len(optimizer.param_groups)} lr_lambdas, but got {len(lr_lambda)}"
  340. )
  341. self.lr_lambdas = list(lr_lambda)
  342. for lr_lambda in self.lr_lambdas:
  343. if not callable(lr_lambda):
  344. raise TypeError(
  345. f"lr_lambda should be a function, but got {type(lr_lambda).__name__}"
  346. )
  347. super().__init__(optimizer, last_epoch)
  348. @override
  349. def state_dict(self) -> dict[str, Any]:
  350. """Return the state of the scheduler as a :class:`dict`.
  351. It contains an entry for every variable in self.__dict__ which
  352. is not the optimizer.
  353. The learning rate lambda functions will only be saved if they are callable objects
  354. and not if they are functions or lambdas.
  355. """
  356. state_dict = {
  357. key: value
  358. for key, value in self.__dict__.items()
  359. if key not in ("optimizer", "lr_lambdas")
  360. }
  361. state_dict["lr_lambdas"] = [None] * len(self.lr_lambdas)
  362. for idx, fn in enumerate(self.lr_lambdas):
  363. if not isinstance(fn, types.FunctionType):
  364. state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
  365. return state_dict
  366. @override
  367. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  368. """Load the scheduler's state.
  369. Args:
  370. state_dict (dict): scheduler state. Should be an object returned
  371. from a call to :meth:`state_dict`.
  372. """
  373. lr_lambdas = state_dict.pop("lr_lambdas")
  374. self.__dict__.update(state_dict)
  375. # Restore state_dict keys in order to prevent side effects
  376. # https://github.com/pytorch/pytorch/issues/32756
  377. state_dict["lr_lambdas"] = lr_lambdas
  378. for idx, fn in enumerate(lr_lambdas):
  379. if fn is not None:
  380. self.lr_lambdas[idx].__dict__.update(fn)
  381. @override
  382. def get_lr(self) -> list[float]:
  383. """Compute the learning rate of each parameter group."""
  384. _warn_get_lr_called_within_step(self)
  385. if not self._is_initial:
  386. return [
  387. group["lr"] * lmbda(self.last_epoch)
  388. for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
  389. ]
  390. else:
  391. return [group["lr"] for group in self.optimizer.param_groups]
  392. class StepLR(LRScheduler):
  393. """Decays the learning rate of each parameter group by gamma every step_size epochs.
  394. Notice that such decay can happen simultaneously with other changes to the learning rate
  395. from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
  396. Args:
  397. optimizer (Optimizer): Wrapped optimizer.
  398. step_size (int): Period of learning rate decay.
  399. gamma (float): Multiplicative factor of learning rate decay.
  400. Default: 0.1.
  401. last_epoch (int): The index of last epoch. Default: -1.
  402. Example:
  403. >>> # xdoctest: +SKIP
  404. >>> # Assuming optimizer uses lr = 0.05 for all groups
  405. >>> # lr = 0.05 if epoch < 30
  406. >>> # lr = 0.005 if 30 <= epoch < 60
  407. >>> # lr = 0.0005 if 60 <= epoch < 90
  408. >>> # ...
  409. >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
  410. >>> for epoch in range(100):
  411. >>> train(...)
  412. >>> validate(...)
  413. >>> scheduler.step()
  414. .. image:: ../scripts/lr_scheduler_images/StepLR.png
  415. """
  416. def __init__(
  417. self,
  418. optimizer: Optimizer,
  419. step_size: int,
  420. gamma: float = 0.1,
  421. last_epoch: int = -1,
  422. ) -> None: # noqa: D107
  423. self.step_size = step_size
  424. self.gamma = gamma
  425. super().__init__(optimizer, last_epoch)
  426. @override
  427. def get_lr(self) -> list[float]:
  428. """Compute the learning rate of each parameter group."""
  429. _warn_get_lr_called_within_step(self)
  430. if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0):
  431. return [group["lr"] for group in self.optimizer.param_groups]
  432. return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
  433. def _get_closed_form_lr(self) -> list[float]:
  434. return [
  435. base_lr * self.gamma ** (self.last_epoch // self.step_size)
  436. for base_lr in self.base_lrs
  437. ]
  438. class MultiStepLR(LRScheduler):
  439. """Decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones.
  440. Notice that such decay can happen simultaneously with other changes to the learning rate
  441. from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
  442. Args:
  443. optimizer (Optimizer): Wrapped optimizer.
  444. milestones (list): List of epoch indices. Must be increasing.
  445. gamma (float): Multiplicative factor of learning rate decay.
  446. Default: 0.1.
  447. last_epoch (int): The index of last epoch. Default: -1.
  448. Example:
  449. >>> # xdoctest: +SKIP
  450. >>> # Assuming optimizer uses lr = 0.05 for all groups
  451. >>> # lr = 0.05 if epoch < 30
  452. >>> # lr = 0.005 if 30 <= epoch < 80
  453. >>> # lr = 0.0005 if epoch >= 80
  454. >>> scheduler = MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)
  455. >>> for epoch in range(100):
  456. >>> train(...)
  457. >>> validate(...)
  458. >>> scheduler.step()
  459. .. image:: ../scripts/lr_scheduler_images/MultiStepLR.png
  460. """
  461. def __init__(
  462. self,
  463. optimizer: Optimizer,
  464. milestones: Iterable[int],
  465. gamma: float = 0.1,
  466. last_epoch: int = -1,
  467. ) -> None: # noqa: D107
  468. self.milestones = Counter(milestones)
  469. self.gamma = gamma
  470. super().__init__(optimizer, last_epoch)
  471. @override
  472. def get_lr(self) -> list[float]:
  473. """Compute the learning rate of each parameter group."""
  474. _warn_get_lr_called_within_step(self)
  475. if self.last_epoch not in self.milestones:
  476. return [group["lr"] for group in self.optimizer.param_groups]
  477. return [
  478. group["lr"] * self.gamma ** self.milestones[self.last_epoch]
  479. for group in self.optimizer.param_groups
  480. ]
  481. def _get_closed_form_lr(self):
  482. milestones = sorted(self.milestones.elements())
  483. return [
  484. base_lr * self.gamma ** bisect_right(milestones, self.last_epoch)
  485. for base_lr in self.base_lrs
  486. ]
  487. class ConstantLR(LRScheduler):
  488. """Multiply the learning rate of each parameter group by a small constant factor.
  489. The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters.
  490. Notice that such multiplication of the small constant factor can
  491. happen simultaneously with other changes to the learning rate from outside this scheduler.
  492. When last_epoch=-1, sets initial lr as lr.
  493. Args:
  494. optimizer (Optimizer): Wrapped optimizer.
  495. factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
  496. total_iters (int): The number of steps that the scheduler multiplies the learning rate by the factor.
  497. Default: 5.
  498. last_epoch (int): The index of the last epoch. Default: -1.
  499. Example:
  500. >>> # xdoctest: +SKIP
  501. >>> # Assuming optimizer uses lr = 0.05 for all groups
  502. >>> # lr = 0.025 if epoch == 0
  503. >>> # lr = 0.025 if epoch == 1
  504. >>> # lr = 0.025 if epoch == 2
  505. >>> # lr = 0.025 if epoch == 3
  506. >>> # ...
  507. >>> # lr = 0.05 if epoch >= 40
  508. >>> scheduler = ConstantLR(optimizer, factor=0.5, total_iters=40)
  509. >>> for epoch in range(100):
  510. >>> train(...)
  511. >>> validate(...)
  512. >>> scheduler.step()
  513. .. image:: ../scripts/lr_scheduler_images/ConstantLR.png
  514. """
  515. def __init__(
  516. self,
  517. optimizer: Optimizer,
  518. factor: float = 1.0 / 3,
  519. total_iters: int = 5,
  520. last_epoch: int = -1,
  521. ) -> None: # noqa: D107
  522. if factor > 1.0 or factor < 0:
  523. raise ValueError(
  524. "Constant multiplicative factor expected to be between 0 and 1."
  525. )
  526. self.factor = factor
  527. self.total_iters = total_iters
  528. super().__init__(optimizer, last_epoch)
  529. @override
  530. def get_lr(self) -> list[float]:
  531. """Compute the learning rate of each parameter group."""
  532. _warn_get_lr_called_within_step(self)
  533. if self.last_epoch == 0:
  534. return [group["lr"] * self.factor for group in self.optimizer.param_groups]
  535. if self.last_epoch != self.total_iters:
  536. return [group["lr"] for group in self.optimizer.param_groups]
  537. return [
  538. group["lr"] * (1.0 / self.factor) for group in self.optimizer.param_groups
  539. ]
  540. def _get_closed_form_lr(self):
  541. return [
  542. base_lr
  543. * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor))
  544. for base_lr in self.base_lrs
  545. ]
  546. class LinearLR(LRScheduler):
  547. """Decays the learning rate of each parameter group by linearly changing small multiplicative factor.
  548. The multiplication is done until the number of epoch reaches a pre-defined milestone: total_iters.
  549. Notice that such decay can happen simultaneously with other changes to the learning rate
  550. from outside this scheduler. When last_epoch=-1, sets initial lr as lr.
  551. Args:
  552. optimizer (Optimizer): Wrapped optimizer.
  553. start_factor (float): The number we multiply learning rate in the first epoch.
  554. The multiplication factor changes towards end_factor in the following epochs.
  555. Default: 1./3.
  556. end_factor (float): The number we multiply learning rate at the end of linear changing
  557. process. Default: 1.0.
  558. total_iters (int): The number of iterations that multiplicative factor reaches to 1.
  559. Default: 5.
  560. last_epoch (int): The index of the last epoch. Default: -1.
  561. Example:
  562. >>> # xdoctest: +SKIP
  563. >>> # Assuming optimizer uses lr = 0.05 for all groups
  564. >>> # lr = 0.003687 if epoch == 0
  565. >>> # lr = 0.004875 if epoch == 1
  566. >>> # lr = 0.006062 if epoch == 2
  567. >>> # lr = 0.00725 if epoch == 3
  568. >>> # ...
  569. >>> # lr = 0.05 if epoch >= 40
  570. >>> scheduler = LinearLR(optimizer, start_factor=0.05, total_iters=40)
  571. >>> for epoch in range(100):
  572. >>> train(...)
  573. >>> validate(...)
  574. >>> scheduler.step()
  575. .. image:: ../scripts/lr_scheduler_images/LinearLR.png
  576. """
  577. def __init__(
  578. self,
  579. optimizer: Optimizer,
  580. start_factor: float = 1.0 / 3,
  581. end_factor: float = 1.0,
  582. total_iters: int = 5,
  583. last_epoch: int = -1,
  584. ) -> None: # noqa: D107
  585. if start_factor > 1.0 or start_factor <= 0:
  586. raise ValueError(
  587. "Starting multiplicative factor expected to be greater than 0 and less or equal to 1."
  588. )
  589. if end_factor > 1.0 or end_factor < 0:
  590. raise ValueError(
  591. "Ending multiplicative factor expected to be between 0 and 1."
  592. )
  593. self.start_factor = start_factor
  594. self.end_factor = end_factor
  595. self.total_iters = total_iters
  596. super().__init__(optimizer, last_epoch)
  597. @override
  598. def get_lr(self) -> list[float]:
  599. """Compute the learning rate."""
  600. _warn_get_lr_called_within_step(self)
  601. if self.last_epoch == 0:
  602. return [
  603. group["lr"] * self.start_factor for group in self.optimizer.param_groups
  604. ]
  605. if self._is_initial or self.last_epoch > self.total_iters:
  606. return [group["lr"] for group in self.optimizer.param_groups]
  607. return [
  608. group["lr"]
  609. * (
  610. 1.0
  611. + (self.end_factor - self.start_factor)
  612. / (
  613. self.total_iters * self.start_factor
  614. + (self.last_epoch - 1) * (self.end_factor - self.start_factor)
  615. )
  616. )
  617. for group in self.optimizer.param_groups
  618. ]
  619. def _get_closed_form_lr(self):
  620. return [
  621. base_lr
  622. * (
  623. self.start_factor
  624. + (self.end_factor - self.start_factor)
  625. * min(self.total_iters, self.last_epoch)
  626. / self.total_iters
  627. )
  628. for base_lr in self.base_lrs
  629. ]
  630. class ExponentialLR(LRScheduler):
  631. """Decays the learning rate of each parameter group by gamma every epoch.
  632. When last_epoch=-1, sets initial lr as lr.
  633. Args:
  634. optimizer (Optimizer): Wrapped optimizer.
  635. gamma (float): Multiplicative factor of learning rate decay.
  636. last_epoch (int): The index of last epoch. Default: -1.
  637. Example:
  638. >>> # xdoctest: +SKIP
  639. >>> scheduler = ExponentialLR(optimizer, gamma=0.95)
  640. >>> for epoch in range(100):
  641. >>> train(...)
  642. >>> validate(...)
  643. >>> scheduler.step()
  644. .. image:: ../scripts/lr_scheduler_images/ExponentialLR.png
  645. """
  646. def __init__(
  647. self,
  648. optimizer: Optimizer,
  649. gamma: float,
  650. last_epoch: int = -1,
  651. ) -> None: # noqa: D107
  652. self.gamma = gamma
  653. super().__init__(optimizer, last_epoch)
  654. @override
  655. def get_lr(self) -> list[float]:
  656. """Compute the learning rate of each parameter group."""
  657. _warn_get_lr_called_within_step(self)
  658. # when loading from a checkpoint, we don't want _initial_step (called from the constructor)
  659. # to update the lr one more step ahead of itself.
  660. if self._is_initial:
  661. return [group["lr"] for group in self.optimizer.param_groups]
  662. return [group["lr"] * self.gamma for group in self.optimizer.param_groups]
  663. def _get_closed_form_lr(self):
  664. return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
  665. class SequentialLR(LRScheduler):
  666. """Contains a list of schedulers expected to be called sequentially during the optimization process.
  667. Specifically, the schedulers will be called according to the milestone points, which should provide exact
  668. intervals by which each scheduler should be called at a given epoch.
  669. Args:
  670. optimizer (Optimizer): Wrapped optimizer.
  671. schedulers (list): List of chained schedulers.
  672. milestones (list): List of integers that reflects milestone points.
  673. last_epoch (int): The index of last epoch. Default: -1.
  674. Example:
  675. >>> # xdoctest: +SKIP
  676. >>> # Assuming optimizer uses lr = 0.05 for all groups
  677. >>> # lr = 0.005 if epoch == 0
  678. >>> # lr = 0.005 if epoch == 1
  679. >>> # lr = 0.005 if epoch == 2
  680. >>> # ...
  681. >>> # lr = 0.05 if epoch == 20
  682. >>> # lr = 0.045 if epoch == 21
  683. >>> # lr = 0.0405 if epoch == 22
  684. >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
  685. >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
  686. >>> scheduler = SequentialLR(
  687. ... optimizer,
  688. ... schedulers=[scheduler1, scheduler2],
  689. ... milestones=[20],
  690. ... )
  691. >>> for epoch in range(100):
  692. >>> train(...)
  693. >>> validate(...)
  694. >>> scheduler.step()
  695. .. image:: ../scripts/lr_scheduler_images/SequentialLR.png
  696. """
  697. def __init__(
  698. self,
  699. optimizer: Optimizer,
  700. schedulers: list[LRScheduler],
  701. milestones: list[int],
  702. last_epoch: int = -1,
  703. ) -> None: # noqa: D107
  704. if len(schedulers) < 1:
  705. raise ValueError(
  706. f"{self.__class__.__name__} expects at least one scheduler, but got no scheduler."
  707. )
  708. for scheduler_idx, scheduler in enumerate(schedulers):
  709. if not hasattr(scheduler, "optimizer"):
  710. raise TypeError(
  711. f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
  712. )
  713. if isinstance(scheduler, ReduceLROnPlateau):
  714. raise ValueError(
  715. f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
  716. "requires additional kwargs to be specified when calling `step`, "
  717. f"but got one at index {scheduler_idx} in the given schedulers sequence."
  718. )
  719. if optimizer != scheduler.optimizer:
  720. raise ValueError(
  721. f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
  722. f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
  723. f"which is different from {optimizer.__class__.__name__}."
  724. )
  725. if len(milestones) != len(schedulers) - 1:
  726. raise ValueError(
  727. "Sequential Schedulers expects number of schedulers provided to be one more "
  728. f"than the number of milestone points, but got number of schedulers {len(schedulers)} and the "
  729. f"number of milestones to be equal to {len(milestones)}"
  730. )
  731. self._schedulers = schedulers
  732. self._milestones = milestones
  733. self.last_epoch = last_epoch + 1
  734. self.optimizer = optimizer
  735. # Reset learning rates back to initial values
  736. for group in self.optimizer.param_groups:
  737. group["lr"] = group["initial_lr"]
  738. # "Undo" the step performed by other schedulers
  739. self.recursive_undo()
  740. # Perform the initial step for only the first scheduler
  741. self._schedulers[0]._initial_step()
  742. self._last_lr = schedulers[0].get_last_lr()
  743. def recursive_undo(self, sched=None):
  744. """
  745. Recursively undo any step performed by the initialisation of
  746. schedulers.
  747. """
  748. scheds = self if sched is None else sched
  749. if hasattr(scheds, "_schedulers"):
  750. for s in scheds._schedulers:
  751. self.recursive_undo(s)
  752. elif hasattr(scheds, "last_epoch"):
  753. scheds.last_epoch -= 1
  754. def step(self) -> None: # type: ignore[override]
  755. """Perform a step."""
  756. self.last_epoch += 1
  757. idx = bisect_right(self._milestones, self.last_epoch)
  758. scheduler = self._schedulers[idx]
  759. if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
  760. scheduler._update_lr(0)
  761. else:
  762. scheduler.step()
  763. self._last_lr = scheduler.get_last_lr()
  764. @override
  765. def state_dict(self) -> dict[str, Any]:
  766. """Return the state of the scheduler as a :class:`dict`.
  767. It contains an entry for every variable in self.__dict__ which
  768. is not the optimizer.
  769. The wrapped scheduler states will also be saved.
  770. """
  771. state_dict = {
  772. key: value
  773. for key, value in self.__dict__.items()
  774. if key not in ("optimizer", "_schedulers")
  775. }
  776. state_dict["_schedulers"] = [None] * len(self._schedulers)
  777. for idx, s in enumerate(self._schedulers):
  778. state_dict["_schedulers"][idx] = s.state_dict()
  779. return state_dict
  780. @override
  781. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  782. """Load the scheduler's state.
  783. Args:
  784. state_dict (dict): scheduler state. Should be an object returned
  785. from a call to :meth:`state_dict`.
  786. """
  787. _schedulers = state_dict.pop("_schedulers")
  788. self.__dict__.update(state_dict)
  789. # Restore state_dict keys in order to prevent side effects
  790. # https://github.com/pytorch/pytorch/issues/32756
  791. state_dict["_schedulers"] = _schedulers
  792. for idx, s in enumerate(_schedulers):
  793. self._schedulers[idx].load_state_dict(s)
  794. class PolynomialLR(LRScheduler):
  795. """Decays the learning rate of each parameter group using a polynomial function in the given total_iters.
  796. When last_epoch=-1, sets initial lr as lr.
  797. Args:
  798. optimizer (Optimizer): Wrapped optimizer.
  799. total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5.
  800. power (float): The power of the polynomial. Default: 1.0.
  801. Example:
  802. >>> # xdoctest: +SKIP("undefined vars")
  803. >>> # Assuming optimizer uses lr = 0.05 for all groups
  804. >>> # lr = 0.0490 if epoch == 0
  805. >>> # lr = 0.0481 if epoch == 1
  806. >>> # lr = 0.0472 if epoch == 2
  807. >>> # ...
  808. >>> # lr = 0.0 if epoch >= 50
  809. >>> scheduler = PolynomialLR(optimizer, total_iters=50, power=0.9)
  810. >>> for epoch in range(100):
  811. >>> train(...)
  812. >>> validate(...)
  813. >>> scheduler.step()
  814. .. image:: ../scripts/lr_scheduler_images/PolynomialLR.png
  815. """
  816. def __init__(
  817. self,
  818. optimizer: Optimizer,
  819. total_iters: int = 5,
  820. power: float = 1.0,
  821. last_epoch: int = -1,
  822. ) -> None: # noqa: D107
  823. self.total_iters = total_iters
  824. self.power = power
  825. super().__init__(optimizer, last_epoch)
  826. @override
  827. def get_lr(self) -> list[float]:
  828. """Compute the learning rate."""
  829. _warn_get_lr_called_within_step(self)
  830. if self._is_initial or self.last_epoch > self.total_iters:
  831. return [group["lr"] for group in self.optimizer.param_groups]
  832. decay_factor = (
  833. (1.0 - self.last_epoch / self.total_iters)
  834. / (1.0 - (self.last_epoch - 1) / self.total_iters)
  835. ) ** self.power
  836. return [group["lr"] * decay_factor for group in self.optimizer.param_groups]
  837. def _get_closed_form_lr(self):
  838. return [
  839. (
  840. base_lr
  841. * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters)
  842. ** self.power
  843. )
  844. for base_lr in self.base_lrs
  845. ]
  846. class CosineAnnealingLR(LRScheduler):
  847. r"""
  848. Set the learning rate of each parameter group using a cosine annealing schedule.
  849. The learning rate is updated recursively using:
  850. .. math::
  851. \eta_{t+1} = \eta_{\min} + (\eta_t - \eta_{\min}) \cdot
  852. \frac{1 + \cos\left(\frac{(T_{cur}+1) \pi}{T_{max}}\right)}
  853. {1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right)}
  854. This implements a recursive approximation of the closed-form schedule proposed in
  855. `SGDR: Stochastic Gradient Descent with Warm Restarts`_:
  856. .. math::
  857. \eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min}) \left(
  858. 1 + \cos\left(\frac{T_{cur} \pi}{T_{max}}\right) \right)
  859. where:
  860. - :math:`\eta_t` is the learning rate at step :math:`t`
  861. - :math:`T_{cur}` is the number of epochs since the last restart
  862. - :math:`T_{max}` is the maximum number of epochs in a cycle
  863. Note:
  864. Although SGDR includes periodic restarts, this implementation performs cosine annealing
  865. **without restarts**, so :math:`T_{cur} = t` and increases monotonically with each call
  866. to :meth:`step`.
  867. Args:
  868. optimizer (Optimizer): Wrapped optimizer.
  869. T_max (int): Maximum number of iterations.
  870. eta_min (float): Minimum learning rate. Default: 0.
  871. last_epoch (int): The index of the last epoch. Default: -1.
  872. .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
  873. https://arxiv.org/abs/1608.03983
  874. Example:
  875. >>> # xdoctest: +SKIP
  876. >>> num_epochs = 100
  877. >>> scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
  878. >>> for epoch in range(num_epochs):
  879. >>> train(...)
  880. >>> validate(...)
  881. >>> scheduler.step()
  882. .. image:: ../scripts/lr_scheduler_images/CosineAnnealingLR.png
  883. """
  884. def __init__(
  885. self,
  886. optimizer: Optimizer,
  887. T_max: int,
  888. eta_min: float = 0.0,
  889. last_epoch: int = -1,
  890. ) -> None: # noqa: D107
  891. self.T_max = T_max
  892. self.eta_min = eta_min
  893. super().__init__(optimizer, last_epoch)
  894. @override
  895. def get_lr(self) -> list[float]:
  896. """Retrieve the learning rate of each parameter group."""
  897. _warn_get_lr_called_within_step(self)
  898. if self._is_initial:
  899. return [group["lr"] for group in self.optimizer.param_groups]
  900. elif self._step_count == 1 and self.last_epoch > 0:
  901. return [
  902. self.eta_min
  903. + (base_lr - self.eta_min)
  904. * (1 + math.cos((self.last_epoch) * math.pi / self.T_max))
  905. / 2
  906. for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
  907. ]
  908. elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
  909. return [
  910. group["lr"]
  911. + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
  912. for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
  913. ]
  914. return [
  915. (1 + math.cos(math.pi * self.last_epoch / self.T_max))
  916. / (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max))
  917. * (group["lr"] - self.eta_min)
  918. + self.eta_min
  919. for group in self.optimizer.param_groups
  920. ]
  921. def _get_closed_form_lr(self) -> list[float]:
  922. return [
  923. self.eta_min
  924. + (base_lr - self.eta_min)
  925. * (1 + math.cos(math.pi * self.last_epoch / self.T_max))
  926. / 2
  927. for base_lr in self.base_lrs
  928. ]
  929. class ChainedScheduler(LRScheduler):
  930. """Chains a list of learning rate schedulers.
  931. Takes in a sequence of chainable learning rate schedulers and calls their
  932. step() functions consecutively in just one call to step().
  933. Args:
  934. schedulers (sequence): sequence of chained schedulers.
  935. optimizer (Optimizer, optional): Wrapped optimizer. Default: None.
  936. Example:
  937. >>> # xdoctest: +SKIP
  938. >>> # Assuming optimizer uses lr = 0.05 for all groups
  939. >>> # lr = 0.05 if epoch == 0
  940. >>> # lr = 0.0450 if epoch == 1
  941. >>> # lr = 0.0405 if epoch == 2
  942. >>> # ...
  943. >>> # lr = 0.00675 if epoch == 19
  944. >>> # lr = 0.06078 if epoch == 20
  945. >>> # lr = 0.05470 if epoch == 21
  946. >>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
  947. >>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
  948. >>> scheduler = ChainedScheduler([scheduler1, scheduler2], optimizer=optimizer)
  949. >>> for epoch in range(100):
  950. >>> train(...)
  951. >>> validate(...)
  952. >>> scheduler.step()
  953. .. image:: ../scripts/lr_scheduler_images/ChainedScheduler.png
  954. """
  955. def __init__(
  956. self, schedulers: Sequence[LRScheduler], optimizer: Optional[Optimizer] = None
  957. ) -> None: # noqa: D107
  958. if len(schedulers) < 1:
  959. raise ValueError(
  960. f"{self.__class__.__name__} expects at least one scheduler to be chained, but got no scheduler."
  961. )
  962. optimizer = optimizer or schedulers[0].optimizer
  963. for scheduler_idx, scheduler in enumerate(schedulers):
  964. if not hasattr(scheduler, "optimizer"):
  965. raise TypeError(
  966. f"{self.__class__.__name__} at index {scheduler_idx} should have `optimizer` as its attribute."
  967. )
  968. if isinstance(scheduler, ReduceLROnPlateau):
  969. raise ValueError(
  970. f"{self.__class__.__name__} does not support `ReduceLROnPlateau` scheduler as it "
  971. "requires additional kwargs to be specified when calling `step`, "
  972. f"but got one at index {scheduler_idx} in the given schedulers sequence."
  973. )
  974. if optimizer != scheduler.optimizer:
  975. raise ValueError(
  976. f"{self.__class__.__name__} expects all schedulers to belong to the same optimizer, but "
  977. f"got scheduler {scheduler.__class__.__name__} at index {scheduler_idx} has {scheduler.optimizer}, "
  978. f"which is different from {optimizer.__class__.__name__}."
  979. )
  980. self._schedulers = schedulers
  981. self.optimizer = optimizer
  982. self._last_lr = [
  983. group["lr"] for group in self._schedulers[-1].optimizer.param_groups
  984. ]
  985. def step(self) -> None: # type: ignore[override]
  986. """Perform a step."""
  987. for scheduler in self._schedulers:
  988. scheduler.step()
  989. self._last_lr = [
  990. group["lr"] for group in self._schedulers[-1].optimizer.param_groups
  991. ]
  992. @override
  993. def state_dict(self) -> dict[str, Any]:
  994. """Return the state of the scheduler as a :class:`dict`.
  995. It contains an entry for every variable in self.__dict__ which
  996. is not the optimizer.
  997. The wrapped scheduler states will also be saved.
  998. """
  999. state_dict = {
  1000. key: value
  1001. for key, value in self.__dict__.items()
  1002. if key not in ("optimizer", "_schedulers")
  1003. }
  1004. state_dict["_schedulers"] = [None] * len(self._schedulers)
  1005. for idx, s in enumerate(self._schedulers):
  1006. state_dict["_schedulers"][idx] = s.state_dict()
  1007. return state_dict
  1008. @override
  1009. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  1010. """Load the scheduler's state.
  1011. Args:
  1012. state_dict (dict): scheduler state. Should be an object returned
  1013. from a call to :meth:`state_dict`.
  1014. """
  1015. _schedulers = state_dict.pop("_schedulers")
  1016. self.__dict__.update(state_dict)
  1017. # Restore state_dict keys in order to prevent side effects
  1018. # https://github.com/pytorch/pytorch/issues/32756
  1019. state_dict["_schedulers"] = _schedulers
  1020. for idx, s in enumerate(_schedulers):
  1021. self._schedulers[idx].load_state_dict(s)
  1022. class ReduceLROnPlateau(LRScheduler):
  1023. """Reduce learning rate when a metric has stopped improving.
  1024. Models often benefit from reducing the learning rate by a factor
  1025. of 2-10 once learning stagnates. This scheduler reads a metrics
  1026. quantity and if no improvement is seen for a 'patience' number
  1027. of epochs, the learning rate is reduced.
  1028. Args:
  1029. optimizer (Optimizer): Wrapped optimizer.
  1030. mode (str): One of `min`, `max`. In `min` mode, lr will
  1031. be reduced when the quantity monitored has stopped
  1032. decreasing; in `max` mode it will be reduced when the
  1033. quantity monitored has stopped increasing. Default: 'min'.
  1034. factor (float): Factor by which the learning rate will be
  1035. reduced. new_lr = lr * factor. Default: 0.1.
  1036. patience (int): The number of allowed epochs with no improvement after
  1037. which the learning rate will be reduced.
  1038. For example, consider the case of having no patience (`patience = 0`).
  1039. In the first epoch, a baseline is established and is always considered good as there's no previous baseline.
  1040. In the second epoch, if the performance is worse than the baseline,
  1041. we have what is considered an intolerable epoch.
  1042. Since the count of intolerable epochs (1) is greater than the patience level (0),
  1043. the learning rate is reduced at the end of this epoch.
  1044. From the third epoch onwards, the learning rate continues to be reduced at the end of each epoch
  1045. if the performance is worse than the baseline. If the performance improves or remains the same,
  1046. the learning rate is not adjusted.
  1047. Default: 10.
  1048. threshold (float): Threshold for measuring the new optimum,
  1049. to only focus on significant changes. Default: 1e-4.
  1050. threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
  1051. dynamic_threshold = best * ( 1 + threshold ) in 'max'
  1052. mode or best * ( 1 - threshold ) in `min` mode.
  1053. In `abs` mode, dynamic_threshold = best + threshold in
  1054. `max` mode or best - threshold in `min` mode. Default: 'rel'.
  1055. cooldown (int): Number of epochs to wait before resuming
  1056. normal operation after lr has been reduced. Default: 0.
  1057. min_lr (float or list): A scalar or a list of scalars. A
  1058. lower bound on the learning rate of all param groups
  1059. or each group respectively. Default: 0.
  1060. eps (float): Minimal decay applied to lr. If the difference
  1061. between new and old lr is smaller than eps, the update is
  1062. ignored. Default: 1e-8.
  1063. Example:
  1064. >>> # xdoctest: +SKIP
  1065. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  1066. >>> scheduler = ReduceLROnPlateau(optimizer, "min")
  1067. >>> for epoch in range(10):
  1068. >>> train(...)
  1069. >>> val_loss = validate(...)
  1070. >>> # Note that step should be called after validate()
  1071. >>> scheduler.step(val_loss)
  1072. .. image:: ../scripts/lr_scheduler_images/ReduceLROnPlateau.png
  1073. """
  1074. def __init__(
  1075. self,
  1076. optimizer: Optimizer,
  1077. mode: Literal["min", "max"] = "min",
  1078. factor: float = 0.1,
  1079. patience: int = 10,
  1080. threshold: float = 1e-4,
  1081. threshold_mode: Literal["rel", "abs"] = "rel",
  1082. cooldown: int = 0,
  1083. min_lr: Union[list[float], float] = 0,
  1084. eps: float = 1e-8,
  1085. ): # noqa: D107
  1086. if factor >= 1.0:
  1087. raise ValueError("Factor should be < 1.0.")
  1088. self.factor = factor
  1089. # Attach optimizer
  1090. if not isinstance(optimizer, Optimizer):
  1091. raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
  1092. self.optimizer = optimizer
  1093. if isinstance(min_lr, (list, tuple)):
  1094. if len(min_lr) != len(optimizer.param_groups):
  1095. raise ValueError(
  1096. f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}"
  1097. )
  1098. self.default_min_lr = None
  1099. self.min_lrs = list(min_lr)
  1100. else:
  1101. self.default_min_lr = min_lr
  1102. self.min_lrs = [min_lr] * len(optimizer.param_groups)
  1103. self.patience = patience
  1104. self.cooldown = cooldown
  1105. self.eps = eps
  1106. self.last_epoch = 0
  1107. self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
  1108. self._init_is_better(
  1109. mode=mode, threshold=threshold, threshold_mode=threshold_mode
  1110. )
  1111. self._reset()
  1112. def _reset(self):
  1113. """Reset num_bad_epochs counter and cooldown counter."""
  1114. self.best = self.mode_worse
  1115. self.cooldown_counter = 0
  1116. self.num_bad_epochs = 0
  1117. def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[override]
  1118. """Perform a step."""
  1119. # convert `metrics` to float, in case it's a zero-dim Tensor
  1120. current = float(metrics)
  1121. if epoch is None:
  1122. epoch = self.last_epoch + 1
  1123. else:
  1124. warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
  1125. self.last_epoch = epoch
  1126. if self._is_better(current, self.best):
  1127. self.best = current
  1128. self.num_bad_epochs = 0
  1129. else:
  1130. self.num_bad_epochs += 1
  1131. if self.in_cooldown:
  1132. self.cooldown_counter -= 1
  1133. self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
  1134. if self.num_bad_epochs > self.patience:
  1135. self._reduce_lr(epoch)
  1136. self.cooldown_counter = self.cooldown
  1137. self.num_bad_epochs = 0
  1138. self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
  1139. def _reduce_lr(self, epoch):
  1140. if len(self.optimizer.param_groups) != len(self.min_lrs):
  1141. if self.default_min_lr is None:
  1142. raise RuntimeError(
  1143. "The number of param groups in the `optimizer` "
  1144. f"({len(self.optimizer.param_groups)}) differs "
  1145. f"from when `ReduceLROnPlateau` was initialized "
  1146. f"({len(self.min_lrs)}), usually due to a new "
  1147. "param group being added to the optimizer. Please "
  1148. "modify the `min_lrs` field to match the length "
  1149. "of the `optimizer` param groups."
  1150. )
  1151. else:
  1152. self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)
  1153. for i, param_group in enumerate(self.optimizer.param_groups):
  1154. old_lr = float(param_group["lr"])
  1155. new_lr = max(old_lr * self.factor, self.min_lrs[i])
  1156. if old_lr - new_lr > self.eps:
  1157. param_group["lr"] = new_lr
  1158. @property
  1159. def in_cooldown(self): # noqa: D102
  1160. return self.cooldown_counter > 0
  1161. def _is_better(self, a, best): # noqa: D102
  1162. if self.mode == "min" and self.threshold_mode == "rel":
  1163. rel_epsilon = 1.0 - self.threshold
  1164. return a < best * rel_epsilon
  1165. elif self.mode == "min" and self.threshold_mode == "abs":
  1166. return a < best - self.threshold
  1167. elif self.mode == "max" and self.threshold_mode == "rel":
  1168. rel_epsilon = self.threshold + 1.0
  1169. return a > best * rel_epsilon
  1170. else: # mode == 'max' and epsilon_mode == 'abs':
  1171. return a > best + self.threshold
  1172. def _init_is_better(self, mode, threshold, threshold_mode):
  1173. if mode not in {"min", "max"}:
  1174. raise ValueError("mode " + mode + " is unknown!")
  1175. if threshold_mode not in {"rel", "abs"}:
  1176. raise ValueError("threshold mode " + threshold_mode + " is unknown!")
  1177. # the worse value for the chosen mode
  1178. if mode == "min":
  1179. self.mode_worse = inf
  1180. else: # mode == 'max':
  1181. self.mode_worse = -inf
  1182. self.mode = mode
  1183. self.threshold = threshold
  1184. self.threshold_mode = threshold_mode
  1185. @override
  1186. def load_state_dict(self, state_dict: dict[str, Any]) -> None:
  1187. """Load the scheduler's state."""
  1188. self.__dict__.update(state_dict)
  1189. self._init_is_better(
  1190. mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode
  1191. )
  1192. class CyclicLR(LRScheduler):
  1193. r"""Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR).
  1194. The policy cycles the learning rate between two boundaries with a constant frequency,
  1195. as detailed in the paper `Cyclical Learning Rates for Training Neural Networks`_.
  1196. The distance between the two boundaries can be scaled on a per-iteration
  1197. or per-cycle basis.
  1198. Cyclical learning rate policy changes the learning rate after every batch.
  1199. `step` should be called after a batch has been used for training.
  1200. This class has three built-in policies, as put forth in the paper:
  1201. * "triangular": A basic triangular cycle without amplitude scaling.
  1202. * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
  1203. * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
  1204. at each cycle iteration.
  1205. This implementation was adapted from the github repo: `bckenstler/CLR`_
  1206. Args:
  1207. optimizer (Optimizer): Wrapped optimizer.
  1208. base_lr (float or list): Initial learning rate which is the
  1209. lower boundary in the cycle for each parameter group.
  1210. max_lr (float or list): Upper learning rate boundaries in the cycle
  1211. for each parameter group. Functionally,
  1212. it defines the cycle amplitude (max_lr - base_lr).
  1213. The lr at any cycle is the sum of base_lr
  1214. and some scaling of the amplitude; therefore
  1215. max_lr may not actually be reached depending on
  1216. scaling function.
  1217. step_size_up (int): Number of training iterations in the
  1218. increasing half of a cycle. Default: 2000
  1219. step_size_down (int): Number of training iterations in the
  1220. decreasing half of a cycle. If step_size_down is None,
  1221. it is set to step_size_up. Default: None
  1222. mode (str): One of {triangular, triangular2, exp_range}.
  1223. Values correspond to policies detailed above.
  1224. If scale_fn is not None, this argument is ignored.
  1225. Default: 'triangular'
  1226. gamma (float): Constant in 'exp_range' scaling function:
  1227. gamma**(cycle iterations)
  1228. Default: 1.0
  1229. scale_fn (function): Custom scaling policy defined by a single
  1230. argument lambda function, where
  1231. 0 <= scale_fn(x) <= 1 for all x >= 0.
  1232. If specified, then 'mode' is ignored.
  1233. Default: None
  1234. scale_mode (str): {'cycle', 'iterations'}.
  1235. Defines whether scale_fn is evaluated on
  1236. cycle number or cycle iterations (training
  1237. iterations since start of cycle).
  1238. Default: 'cycle'
  1239. cycle_momentum (bool): If ``True``, momentum is cycled inversely
  1240. to learning rate between 'base_momentum' and 'max_momentum'.
  1241. Default: True
  1242. base_momentum (float or list): Lower momentum boundaries in the cycle
  1243. for each parameter group. Note that momentum is cycled inversely
  1244. to learning rate; at the peak of a cycle, momentum is
  1245. 'base_momentum' and learning rate is 'max_lr'.
  1246. Default: 0.8
  1247. max_momentum (float or list): Upper momentum boundaries in the cycle
  1248. for each parameter group. Functionally,
  1249. it defines the cycle amplitude (max_momentum - base_momentum).
  1250. The momentum at any cycle is the difference of max_momentum
  1251. and some scaling of the amplitude; therefore
  1252. base_momentum may not actually be reached depending on
  1253. scaling function. Note that momentum is cycled inversely
  1254. to learning rate; at the start of a cycle, momentum is 'max_momentum'
  1255. and learning rate is 'base_lr'
  1256. Default: 0.9
  1257. last_epoch (int): The index of the last batch. This parameter is used when
  1258. resuming a training job. Since `step()` should be invoked after each
  1259. batch instead of after each epoch, this number represents the total
  1260. number of *batches* computed, not the total number of epochs computed.
  1261. When last_epoch=-1, the schedule is started from the beginning.
  1262. Default: -1
  1263. Example:
  1264. >>> # xdoctest: +SKIP
  1265. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  1266. >>> scheduler = torch.optim.lr_scheduler.CyclicLR(
  1267. ... optimizer,
  1268. ... base_lr=0.01,
  1269. ... max_lr=0.1,
  1270. ... step_size_up=10,
  1271. ... )
  1272. >>> data_loader = torch.utils.data.DataLoader(...)
  1273. >>> for epoch in range(10):
  1274. >>> for batch in data_loader:
  1275. >>> train_batch(...)
  1276. >>> scheduler.step()
  1277. .. image:: ../scripts/lr_scheduler_images/CyclicLR.png
  1278. .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
  1279. .. _bckenstler/CLR: https://github.com/bckenstler/CLR
  1280. """
  def __init__(
      self,
      optimizer: Optimizer,
      base_lr: Union[float, list[float]],
      max_lr: Union[float, list[float]],
      step_size_up: int = 2000,
      step_size_down: Optional[int] = None,
      mode: Literal["triangular", "triangular2", "exp_range"] = "triangular",
      gamma: float = 1.0,
      scale_fn: Optional[Callable[[float], float]] = None,
      scale_mode: Literal["cycle", "iterations"] = "cycle",
      cycle_momentum: bool = True,
      base_momentum: Union[float, list[float]] = 0.8,
      max_momentum: Union[float, list[float]] = 0.9,
      last_epoch: int = -1,
  ):
      """Validate the cycle configuration and attach the scheduler to ``optimizer``.

      When ``last_epoch == -1`` the optimizer's param groups are overwritten with
      the starting values of the schedule: ``lr`` becomes ``base_lr`` and, if
      ``cycle_momentum`` is enabled, the momentum (or ``betas[0]``) becomes
      ``max_momentum``. ``max_momentum`` and ``base_momentum`` are also recorded
      on every param group.

      Raises:
          TypeError: If ``optimizer`` is not an :class:`Optimizer`.
          ValueError: If ``step_size_up`` is not positive, ``step_size_down`` is
              negative, ``mode`` is unknown while ``scale_fn`` is ``None``, or
              ``cycle_momentum`` is requested for an optimizer whose defaults
              contain neither ``momentum`` nor ``betas``.
      """
      # Attach optimizer
      if not isinstance(optimizer, Optimizer):
          raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
      self.optimizer = optimizer

      # Check the cycle geometry and the mode before any param group is touched,
      # so that a bad configuration cannot leave the optimizer half-initialised.
      ramp_up = float(step_size_up)
      ramp_down = float(step_size_down) if step_size_down is not None else ramp_up
      if ramp_up <= 0:
          # A zero-length ramp makes `step_ratio` zero, which `get_lr` divides by.
          raise ValueError(f"Expected positive step_size_up, but got {step_size_up}")
      if ramp_down < 0:
          raise ValueError(
              f"Expected non-negative step_size_down, but got {step_size_down}"
          )
      if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None:
          raise ValueError("mode is invalid and scale_fn is None")

      base_lrs = _format_param("base_lr", optimizer, base_lr)
      if last_epoch == -1:
          for lr, group in zip(base_lrs, optimizer.param_groups):
              if isinstance(group["lr"], Tensor):
                  # Keep the existing tensor object and overwrite its value in place.
                  lr_val = lr.item() if isinstance(lr, Tensor) else lr
                  group["lr"].fill_(lr_val)
              else:
                  group["lr"] = lr

      self.max_lrs = _format_param("max_lr", optimizer, max_lr)

      self.total_size = ramp_up + ramp_down
      # Fraction of a full cycle that is spent increasing the learning rate.
      self.step_ratio = ramp_up / self.total_size

      self.mode = mode
      self.gamma = gamma

      # `_scale_fn_ref` is only declared here; `_init_scale_fn` assigns it when no
      # custom `scale_fn` was supplied.
      self._scale_fn_ref: Callable[[float], float]
      self._scale_fn_custom = scale_fn
      self.scale_mode = scale_mode
      self._init_scale_fn()

      self.cycle_momentum = cycle_momentum
      if cycle_momentum:
          if (
              "momentum" not in optimizer.defaults
              and "betas" not in optimizer.defaults
          ):
              raise ValueError(
                  "optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
              )

          # Optimizers configured through `betas` get their first beta cycled.
          self.use_beta1 = "betas" in self.optimizer.defaults
          self.base_momentums = _format_param(
              "base_momentum", optimizer, base_momentum
          )
          self.max_momentums = _format_param("max_momentum", optimizer, max_momentum)
          if last_epoch == -1:
              for m_momentum, b_momentum, group in zip(
                  self.max_momentums, self.base_momentums, optimizer.param_groups
              ):
                  if self.use_beta1:
                      group["betas"] = (m_momentum, *group["betas"][1:])
                  else:
                      group["momentum"] = m_momentum
                  group["max_momentum"] = m_momentum
                  group["base_momentum"] = b_momentum

      super().__init__(optimizer, last_epoch)
      # NOTE(review): assigned after the base initialiser, presumably because the
      # base class sets its own `base_lrs` -- confirm against LRScheduler.__init__.
      self.base_lrs = base_lrs
  1350. def _init_scale_fn(self):
  1351. if self._scale_fn_custom is not None:
  1352. return
  1353. if self.mode == "triangular":
  1354. self._scale_fn_ref = self._triangular_scale_fn
  1355. self.scale_mode = "cycle"
  1356. elif self.mode == "triangular2":
  1357. self._scale_fn_ref = self._triangular2_scale_fn
  1358. self.scale_mode = "cycle"
  1359. elif self.mode == "exp_range":
  1360. self._scale_fn_ref = partial(self._exp_range_scale_fn, self.gamma)
  1361. self.scale_mode = "iterations"
  def scale_fn(self, x) -> float:
      """Evaluate the active scaling policy at ``x``.

      The user-supplied function is used when one was given; otherwise the
      built-in policy selected for the current mode is used.
      """
      custom = self._scale_fn_custom
      return self._scale_fn_ref(x) if custom is None else custom(x)
  @staticmethod
  def _triangular_scale_fn(x: float) -> float:
      """Scaling policy for ``"triangular"`` mode: constant amplitude.

      The argument is ignored; the factor is 1.0 for every input.
      """
      return 1.0
  1371. @staticmethod
  1372. def _triangular2_scale_fn(x: float) -> float:
  1373. return 1 / (2.0 ** (x - 1))
  1374. @staticmethod
  1375. def _exp_range_scale_fn(gamma: float, x: float) -> float:
  1376. return gamma**x
  @override
  def get_lr(self) -> list[float]:
      """Return the learning rate of every param group for the current batch index.

      ``self.last_epoch`` is interpreted as the index of the last batch.

      Side effect: when ``self.cycle_momentum`` is ``True`` the momentum (or
      ``betas[0]``) of each of the optimizer's param groups is overwritten
      with the value that mirrors the learning rate.
      """
      _warn_get_lr_called_within_step(self)

      progress = self.last_epoch / self.total_size
      cycle = math.floor(1 + progress)
      # Position inside the current cycle, in [0, 1).
      x = 1.0 + progress - cycle

      ratio = self.step_ratio
      # Rising edge up to `ratio`, falling edge after it.
      scale_factor = x / ratio if x <= ratio else (x - 1) / (ratio - 1)

      def amplitude() -> float:
          # The policy is fed the cycle number or the raw iteration count,
          # depending on the scale mode.
          if self.scale_mode == "cycle":
              return self.scale_fn(cycle)
          return self.scale_fn(self.last_epoch)

      lrs = [
          low + (high - low) * scale_factor * amplitude()
          for low, high in zip(self.base_lrs, self.max_lrs)
      ]

      if self.cycle_momentum:
          # Momentum moves opposite to the learning rate: it starts at the
          # maximum and is reduced by the scaled height.
          momentums = [
              high - (high - low) * scale_factor * amplitude()
              for low, high in zip(self.base_momentums, self.max_momentums)
          ]
          for group, value in zip(self.optimizer.param_groups, momentums):
              if self.use_beta1:
                  group["betas"] = (value, *group["betas"][1:])
              else:
                  group["momentum"] = value

      return lrs
  @override
  def state_dict(self) -> dict[str, Any]:
      """Return the state of the scheduler as a :class:`dict`.

      It contains an entry for every variable in self.__dict__ which
      is not the optimizer.

      The custom scale function is only saved if it is a callable object that
      carries instance attributes; in that case a copy of its ``__dict__`` is
      stored. Plain functions, lambdas and callables without a ``__dict__``
      (for example builtins) are stored as ``None``.

      When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
      """
      state = super().state_dict()
      # `_scale_fn_ref` is not stored: `load_state_dict` rebuilds it by calling
      # `_init_scale_fn()`.
      state.pop("_scale_fn_ref", None)
      fn = state.pop("_scale_fn_custom")
      state["_scale_fn_custom"] = None
      if fn is not None and not isinstance(fn, types.FunctionType):
          # Not every callable has a `__dict__` (builtins, slotted objects);
          # reading it unconditionally would raise AttributeError.
          attrs = getattr(fn, "__dict__", None)
          if attrs is not None:
              state["_scale_fn_custom"] = attrs.copy()

      return state
  @override
  def load_state_dict(self, state_dict: dict[str, Any]) -> None:
      """Load the scheduler's state.

      ``state_dict`` itself is left unmodified. If it carries saved attributes
      for a custom scale function, they are merged into the ``__dict__`` of the
      scale function this scheduler was constructed with. The built-in scaling
      policy is re-bound afterwards.
      """
      # Work on a shallow copy: popping from the caller's dict would make a
      # second load from the same dict fail with KeyError.
      state = dict(state_dict)
      fn = state.pop("_scale_fn_custom", None)
      super().load_state_dict(state)
      if fn is not None:
          self._scale_fn_custom.__dict__.update(fn)
      self._init_scale_fn()
  class CosineAnnealingWarmRestarts(LRScheduler):
      r"""Set the learning rate of each parameter group using a cosine annealing schedule.

      The :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
      is the number of epochs since the last restart and :math:`T_{i}` is the number
      of epochs between two warm restarts in SGDR:

      .. math::
          \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
          \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

      When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
      When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

      It has been proposed in
      `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

      Args:
          optimizer (Optimizer): Wrapped optimizer.
          T_0 (int): Number of iterations until the first restart.
          T_mult (int, optional): A factor by which :math:`T_{i}` increases after a restart. Default: 1.
          eta_min (float, optional): Minimum learning rate. Default: 0.
          last_epoch (int, optional): The index of the last epoch. Default: -1.

      .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
          https://arxiv.org/abs/1608.03983

      Example:
          >>> # xdoctest: +SKIP
          >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
          >>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
          ...     optimizer, T_0=20
          ... )
          >>> for epoch in range(100):
          >>>     train(...)
          >>>     validate(...)
          >>>     scheduler.step()

      .. image:: ../scripts/lr_scheduler_images/CosineAnnealingWarmRestarts.png
      """

      def __init__(
          self,
          optimizer: Optimizer,
          T_0: int,
          T_mult: int = 1,
          eta_min: float = 0.0,
          last_epoch: int = -1,
      ):
          """Validate the restart configuration and attach to ``optimizer``.

          Raises:
              ValueError: If ``T_0`` is not a positive integer, ``T_mult`` is not
                  an integer >= 1, or ``eta_min`` is not a float or int.
          """
          # Type checks come first so that a non-numeric argument produces the
          # ValueError below rather than a TypeError from the comparison.
          if not isinstance(T_0, int) or T_0 <= 0:
              raise ValueError(f"Expected positive integer T_0, but got {T_0}")
          if not isinstance(T_mult, int) or T_mult < 1:
              raise ValueError(f"Expected integer T_mult >= 1, but got {T_mult}")
          if not isinstance(eta_min, (float, int)):
              raise ValueError(
                  f"Expected float or int eta_min, but got {eta_min} of type {type(eta_min)}"
              )
          self.T_0 = T_0
          self.T_i = T_0
          self.T_mult = T_mult
          self.eta_min = eta_min
          self.T_cur = last_epoch
          super().__init__(optimizer, last_epoch)

      @override
      def get_lr(self) -> list[float]:
          """Compute the learning rate of each parameter group.

          The value is the cosine interpolation between each base learning rate
          and ``eta_min`` at position ``T_cur / T_i`` of the current cycle.
          """
          _warn_get_lr_called_within_step(self)

          return [
              self.eta_min
              + (base_lr - self.eta_min)
              * (1 + math.cos(math.pi * self.T_cur / self.T_i))
              / 2
              for base_lr in self.base_lrs
          ]

      def _cycle_start(self, n: int) -> int:
          """Return the epoch at which restart cycle ``n`` (0-based) begins.

          Only meaningful for ``T_mult > 1``. Cycle lengths are
          ``T_0 * T_mult**k``; their partial sum
          ``T_0 * (T_mult**n - 1) / (T_mult - 1)`` is always an integer, so
          the floor division is exact.
          """
          return self.T_0 * (self.T_mult**n - 1) // (self.T_mult - 1)

      def _cycle_index(self, epoch) -> int:
          """Return the index of the restart cycle that contains ``epoch``.

          Only meaningful for ``T_mult > 1`` and ``epoch >= T_0``. The
          logarithm supplies a first guess, which is then corrected with exact
          integer boundaries: ``math.log`` can land just below a whole number
          (``math.log(1000, 10)`` is ``2.9999999999999996``), and truncating
          that would place an epoch lying exactly on a restart into the
          previous cycle.
          """
          n = int(math.log(epoch / self.T_0 * (self.T_mult - 1) + 1, self.T_mult))
          while self._cycle_start(n + 1) <= epoch:
              n += 1
          while n > 0 and self._cycle_start(n) > epoch:
              n -= 1
          return n

      @override
      def step(self, epoch=None) -> None:
          """Step could be called after every batch update.

          Args:
              epoch: Position to move the schedule to; may be fractional. When
                  omitted, the schedule advances by one from ``last_epoch``.

          Raises:
              ValueError: If ``epoch`` is negative.

          Example:
              >>> # xdoctest: +SKIP("Undefined vars")
              >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
              >>> iters = len(dataloader)
              >>> for epoch in range(20):
              >>>     for i, sample in enumerate(dataloader):
              >>>         inputs, labels = sample['inputs'], sample['labels']
              >>>         optimizer.zero_grad()
              >>>         outputs = net(inputs)
              >>>         loss = criterion(outputs, labels)
              >>>         loss.backward()
              >>>         optimizer.step()
              >>>         scheduler.step(epoch + i / iters)

          This function can be called in an interleaved way.

          Example:
              >>> # xdoctest: +SKIP("Undefined vars")
              >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
              >>> for epoch in range(20):
              >>>     scheduler.step()
              >>> scheduler.step(26)
              >>> scheduler.step()  # scheduler.step(27), instead of scheduler.step(20)
          """
          if epoch is None and self.last_epoch < 0:
              epoch = 0

          if epoch is None:
              # Implicit step: advance by one and roll over into the next
              # (possibly longer) cycle when the current one is exhausted.
              epoch = self.last_epoch + 1
              self.T_cur = self.T_cur + 1
              if self.T_cur >= self.T_i:
                  self.T_cur = self.T_cur % self.T_i
                  self.T_i = self.T_i * self.T_mult
          else:
              # Explicit position: recompute the cycle state from scratch.
              if epoch < 0:
                  raise ValueError(f"Expected non-negative epoch, but got {epoch}")
              if epoch >= self.T_0:
                  if self.T_mult == 1:
                      self.T_cur = epoch % self.T_0
                  else:
                      n = self._cycle_index(epoch)
                      self.T_cur = epoch - self._cycle_start(n)
                      self.T_i = self.T_0 * self.T_mult**n
              else:
                  self.T_i = self.T_0
                  self.T_cur = epoch
          self.last_epoch = math.floor(epoch)

          with _enable_get_lr_call(self):
              for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                  param_group["lr"] = lr

          self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
  class _SchedulePhase(TypedDict):
      """One segment of a OneCycleLR schedule.

      The four string fields are not values themselves: each is the name of a
      param-group key (such as ``"initial_lr"`` or ``"max_momentum"``) whose
      value is looked up when the schedule is evaluated.
      """

      # Step index at which this phase ends; may be fractional.
      end_step: float
      # Param-group key holding the learning rate at the start of the phase.
      start_lr: str
      # Param-group key holding the learning rate at the end of the phase.
      end_lr: str
      # Param-group key holding the momentum at the start of the phase.
      start_momentum: str
      # Param-group key holding the momentum at the end of the phase.
      end_momentum: str
  1574. class OneCycleLR(LRScheduler):
  1575. r"""Sets the learning rate of each parameter group according to the 1cycle learning rate policy.
  1576. The 1cycle policy anneals the learning rate from an initial learning rate to some maximum
  1577. learning rate and then from that maximum learning rate to some minimum learning rate much
  1578. lower than the initial learning rate.
  1579. This policy was initially described in the paper `Super-Convergence:
  1580. Very Fast Training of Neural Networks Using Large Learning Rates`_.
  1581. The 1cycle learning rate policy changes the learning rate after every batch.
  1582. `step` should be called after a batch has been used for training.
  1583. This scheduler is not chainable.
  1584. Note also that the total number of steps in the cycle can be determined in one
  1585. of two ways (listed in order of precedence):
  1586. #. A value for total_steps is explicitly provided.
  1587. #. A number of epochs (epochs) and a number of steps per epoch
  1588. (steps_per_epoch) are provided.
  1589. In this case, the number of total steps is inferred by
  1590. total_steps = epochs * steps_per_epoch
  1591. You must either provide a value for total_steps or provide a value for both
  1592. epochs and steps_per_epoch.
  1593. The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
  1594. claims that "unpublished work has shown even better results by using only two phases". To
  1595. mimic the behaviour of the original paper instead, set ``three_phase=True``.
  1596. Args:
  1597. optimizer (Optimizer): Wrapped optimizer.
  1598. max_lr (float or list): Upper learning rate boundaries in the cycle
  1599. for each parameter group.
  1600. total_steps (int): The total number of steps in the cycle. Note that
  1601. if a value is not provided here, then it must be inferred by providing
  1602. a value for epochs and steps_per_epoch.
  1603. Default: None
  1604. epochs (int): The number of epochs to train for. This is used along
  1605. with steps_per_epoch in order to infer the total number of steps in the cycle
  1606. if a value for total_steps is not provided.
  1607. Default: None
  1608. steps_per_epoch (int): The number of steps per epoch to train for. This is
  1609. used along with epochs in order to infer the total number of steps in the
  1610. cycle if a value for total_steps is not provided.
  1611. Default: None
  1612. pct_start (float): The percentage of the cycle (in number of steps) spent
  1613. increasing the learning rate.
  1614. Default: 0.3
  1615. anneal_strategy (str): {'cos', 'linear'}
  1616. Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
  1617. linear annealing.
  1618. Default: 'cos'
  1619. cycle_momentum (bool): If ``True``, momentum is cycled inversely
  1620. to learning rate between 'base_momentum' and 'max_momentum'.
  1621. Default: True
  1622. base_momentum (float or list): Lower momentum boundaries in the cycle
  1623. for each parameter group. Note that momentum is cycled inversely
  1624. to learning rate; at the peak of a cycle, momentum is
  1625. 'base_momentum' and learning rate is 'max_lr'.
  1626. Default: 0.85
  1627. max_momentum (float or list): Upper momentum boundaries in the cycle
  1628. for each parameter group. Functionally,
  1629. it defines the cycle amplitude (max_momentum - base_momentum).
  1630. Note that momentum is cycled inversely
  1631. to learning rate; at the start of a cycle, momentum is 'max_momentum'
  1632. and learning rate is 'base_lr'
  1633. Default: 0.95
  1634. div_factor (float): Determines the initial learning rate via
  1635. initial_lr = max_lr/div_factor
  1636. Default: 25
  1637. final_div_factor (float): Determines the minimum learning rate via
  1638. min_lr = initial_lr/final_div_factor
  1639. Default: 1e4
  1640. three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
  1641. learning rate according to 'final_div_factor' instead of modifying the second
  1642. phase (the first two phases will be symmetrical about the step indicated by
  1643. 'pct_start').
  1644. last_epoch (int): The index of the last batch. This parameter is used when
  1645. resuming a training job. Since `step()` should be invoked after each
  1646. batch instead of after each epoch, this number represents the total
  1647. number of *batches* computed, not the total number of epochs computed.
  1648. When last_epoch=-1, the schedule is started from the beginning.
  1649. Default: -1
  1650. Example:
  1651. >>> # xdoctest: +SKIP
  1652. >>> data_loader = torch.utils.data.DataLoader(...)
  1653. >>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
  1654. >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(
  1655. ... optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10
  1656. ... )
  1657. >>> for epoch in range(10):
  1658. >>> for batch in data_loader:
  1659. >>> train_batch(...)
  1660. >>> optimizer.step()
  1661. >>> scheduler.step()
  1662. .. image:: ../scripts/lr_scheduler_images/OneCycleLR.png
  1663. .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
  1664. https://arxiv.org/abs/1708.07120
  1665. """
  def __init__(
      self,
      optimizer: Optimizer,
      max_lr: Union[float, list[float]],
      total_steps: Optional[int] = None,
      epochs: Optional[int] = None,
      steps_per_epoch: Optional[int] = None,
      pct_start: float = 0.3,
      anneal_strategy: Literal["cos", "linear"] = "cos",
      cycle_momentum: bool = True,
      base_momentum: Union[float, list[float]] = 0.85,
      max_momentum: Union[float, list[float]] = 0.95,
      div_factor: float = 25.0,
      final_div_factor: float = 1e4,
      three_phase: bool = False,
      last_epoch: int = -1,
  ):
      """Validate the 1cycle configuration and attach the scheduler to ``optimizer``.

      When ``last_epoch == -1`` every param group receives ``initial_lr``,
      ``max_lr`` and ``min_lr`` entries and, if ``cycle_momentum`` is enabled,
      its momentum (or ``betas[0]``) is set to ``max_momentum`` and the
      ``max_momentum`` / ``base_momentum`` entries are recorded.

      Raises:
          TypeError: If ``optimizer`` is not an :class:`Optimizer`.
          ValueError: If the total number of steps cannot be determined or is
              not a positive integer, ``pct_start`` is not a float in [0, 1],
              ``anneal_strategy`` is unknown, or ``cycle_momentum`` is
              requested for an optimizer whose defaults contain neither
              ``momentum`` nor ``betas``.
      """
      # Validate optimizer
      if not isinstance(optimizer, Optimizer):
          raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
      self.optimizer = optimizer

      # Validate total_steps. The isinstance test runs first so that a value of
      # the wrong type yields the ValueError below instead of a TypeError from
      # the comparison.
      if total_steps is not None:
          if not isinstance(total_steps, int) or total_steps <= 0:
              raise ValueError(
                  f"Expected positive integer total_steps, but got {total_steps}"
              )
          self.total_steps = total_steps
      elif epochs is not None and steps_per_epoch is not None:
          if not isinstance(epochs, int) or epochs <= 0:
              raise ValueError(f"Expected positive integer epochs, but got {epochs}")
          if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
              raise ValueError(
                  f"Expected positive integer steps_per_epoch, but got {steps_per_epoch}"
              )
          self.total_steps = epochs * steps_per_epoch
      else:
          raise ValueError(
              "You must define either total_steps OR (epochs AND steps_per_epoch)"
          )

      # Validate pct_start before it is used to place the phase boundaries.
      if not isinstance(pct_start, float) or pct_start < 0 or pct_start > 1:
          raise ValueError(
              f"Expected float between 0 and 1 pct_start, but got {pct_start}"
          )

      # Validate anneal_strategy
      if anneal_strategy not in ["cos", "linear"]:
          raise ValueError(
              f"anneal_strategy must be one of 'cos' or 'linear', instead got {anneal_strategy}"
          )
      self._anneal_func_type = anneal_strategy

      # Each phase names the param-group keys to anneal between; `end_step` is
      # the (possibly fractional) step index at which the phase finishes.
      self._schedule_phases: list[_SchedulePhase]
      if three_phase:
          self._schedule_phases = [
              {
                  "end_step": float(pct_start * self.total_steps) - 1,
                  "start_lr": "initial_lr",
                  "end_lr": "max_lr",
                  "start_momentum": "max_momentum",
                  "end_momentum": "base_momentum",
              },
              {
                  "end_step": float(2 * pct_start * self.total_steps) - 2,
                  "start_lr": "max_lr",
                  "end_lr": "initial_lr",
                  "start_momentum": "base_momentum",
                  "end_momentum": "max_momentum",
              },
              {
                  "end_step": self.total_steps - 1,
                  "start_lr": "initial_lr",
                  "end_lr": "min_lr",
                  "start_momentum": "max_momentum",
                  "end_momentum": "max_momentum",
              },
          ]
      else:
          self._schedule_phases = [
              {
                  "end_step": float(pct_start * self.total_steps) - 1,
                  "start_lr": "initial_lr",
                  "end_lr": "max_lr",
                  "start_momentum": "max_momentum",
                  "end_momentum": "base_momentum",
              },
              {
                  "end_step": self.total_steps - 1,
                  "start_lr": "max_lr",
                  "end_lr": "min_lr",
                  "start_momentum": "base_momentum",
                  "end_momentum": "max_momentum",
              },
          ]

      # Initialize learning rate variables
      max_lrs = _format_param("max_lr", self.optimizer, max_lr)
      if last_epoch == -1:
          for idx, group in enumerate(self.optimizer.param_groups):
              group["initial_lr"] = max_lrs[idx] / div_factor
              group["max_lr"] = max_lrs[idx]
              group["min_lr"] = group["initial_lr"] / final_div_factor

      # Initialize momentum variables
      self.cycle_momentum = cycle_momentum
      if self.cycle_momentum:
          if (
              "momentum" not in self.optimizer.defaults
              and "betas" not in self.optimizer.defaults
          ):
              raise ValueError(
                  "optimizer must support momentum or beta1 with `cycle_momentum` option enabled"
              )
          # Optimizers configured through `betas` get their first beta cycled.
          self.use_beta1 = "betas" in self.optimizer.defaults
          max_momentums = _format_param("max_momentum", optimizer, max_momentum)
          base_momentums = _format_param("base_momentum", optimizer, base_momentum)
          if last_epoch == -1:
              for m_momentum, b_momentum, group in zip(
                  max_momentums, base_momentums, optimizer.param_groups
              ):
                  if self.use_beta1:
                      group["betas"] = (m_momentum, *group["betas"][1:])
                  else:
                      group["momentum"] = m_momentum
                  group["max_momentum"] = m_momentum
                  group["base_momentum"] = b_momentum

      super().__init__(optimizer, last_epoch)
  1791. def _anneal_func(self, *args, **kwargs):
  1792. if hasattr(self, "_anneal_func_type"):
  1793. if self._anneal_func_type == "cos":
  1794. return self._annealing_cos(*args, **kwargs)
  1795. elif self._anneal_func_type == "linear":
  1796. return self._annealing_linear(*args, **kwargs)
  1797. else:
  1798. raise ValueError(f"Unknown _anneal_func_type: {self._anneal_func_type}")
  1799. else:
  1800. # For BC
  1801. return self.anneal_func(*args, **kwargs) # type: ignore[attr-defined]
  1802. @staticmethod
  1803. def _annealing_cos(start, end, pct):
  1804. """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
  1805. cos_out = math.cos(math.pi * pct) + 1
  1806. return end + (start - end) / 2.0 * cos_out
  1807. @staticmethod
  1808. def _annealing_linear(start, end, pct):
  1809. """Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
  1810. return (end - start) * pct + start
  @override
  def get_lr(self) -> list[float]:
      """Compute the learning rate of each parameter group.

      ``self.last_epoch`` is used as the current step number. The step is
      located in one of ``self._schedule_phases`` and its relative position
      inside that phase is handed to the annealing function.

      Side effect: when ``self.cycle_momentum`` is set, the momentum (or
      ``betas[0]``) of every param group is updated as well.

      Raises:
          ValueError: If the step number exceeds ``self.total_steps``.
      """
      _warn_get_lr_called_within_step(self)

      step_num = self.last_epoch

      if step_num > self.total_steps:
          raise ValueError(
              f"Tried to step {step_num} times. The specified number of total steps is {self.total_steps}"
          )

      # The active phase and the position inside it depend only on the step
      # number, so they are resolved once rather than once per param group.
      phases = self._schedule_phases
      final_index = len(phases) - 1
      start_step = 0.0
      for i, phase in enumerate(phases):
          end_step = phase["end_step"]
          # The last phase also absorbs any step beyond its nominal end.
          if step_num <= end_step or i == final_index:
              break
          start_step = end_step

      phase_length = end_step - start_step
      # A phase can have zero length (e.g. pct_start * total_steps == 1 makes
      # the first phase start and end at step 0). Dividing would raise
      # ZeroDivisionError; such a phase is treated as already completed, which
      # equals the starting value of the phase that follows it.
      pct = (step_num - start_step) / phase_length if phase_length != 0 else 1.0

      lrs = []
      for group in self.optimizer.param_groups:
          computed_lr = self._anneal_func(
              group[phase["start_lr"]], group[phase["end_lr"]], pct
          )
          if self.cycle_momentum:
              computed_momentum = self._anneal_func(
                  group[phase["start_momentum"]],
                  group[phase["end_momentum"]],
                  pct,
              )
              if self.use_beta1:
                  group["betas"] = (computed_momentum, *group["betas"][1:])
              else:
                  group["momentum"] = computed_momentum
          lrs.append(computed_lr)

      return lrs