# optimization.py (39 KB)
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """PyTorch optimization for BERT model."""
  15. import math
  16. import warnings
  17. from functools import partial
  18. from typing import Optional, Union
  19. import torch
  20. from torch.optim import Optimizer
  21. from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau
  22. from .trainer_pt_utils import LayerWiseDummyOptimizer, LayerWiseDummyScheduler
  23. from .trainer_utils import SchedulerType
  24. from .utils import logging
# Module-level logger, obtained from the package's `logging` utility and keyed by this module's name.
logger = logging.get_logger(__name__)
  26. def _get_constant_lambda(_=None):
  27. return 1
  28. def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
  29. """
  30. Create a schedule with a constant learning rate, using the learning rate set in optimizer.
  31. Args:
  32. optimizer ([`~torch.optim.Optimizer`]):
  33. The optimizer for which to schedule the learning rate.
  34. last_epoch (`int`, *optional*, defaults to -1):
  35. The index of the last epoch when resuming training.
  36. Return:
  37. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  38. """
  39. return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)
  40. def get_reduce_on_plateau_schedule(optimizer: Optimizer, **kwargs):
  41. """
  42. Create a schedule with a constant learning rate that decreases when a metric has stopped improving.
  43. Args:
  44. optimizer ([`~torch.optim.Optimizer`]):
  45. The optimizer for which to schedule the learning rate.
  46. kwargs (`dict`, *optional*):
  47. Extra parameters to be passed to the scheduler. See `torch.optim.lr_scheduler.ReduceLROnPlateau`
  48. for possible parameters.
  49. Return:
  50. `torch.optim.lr_scheduler.ReduceLROnPlateau` with the appropriate schedule.
  51. """
  52. return ReduceLROnPlateau(optimizer, **kwargs)
  53. def _get_constant_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int):
  54. if current_step < num_warmup_steps:
  55. return float(current_step) / float(max(1.0, num_warmup_steps))
  56. return 1.0
  57. def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
  58. """
  59. Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
  60. increases linearly between 0 and the initial lr set in the optimizer.
  61. Args:
  62. optimizer ([`~torch.optim.Optimizer`]):
  63. The optimizer for which to schedule the learning rate.
  64. num_warmup_steps (`int`):
  65. The number of steps for the warmup phase.
  66. last_epoch (`int`, *optional*, defaults to -1):
  67. The index of the last epoch when resuming training.
  68. Return:
  69. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  70. """
  71. lr_lambda = partial(_get_constant_schedule_with_warmup_lr_lambda, num_warmup_steps=num_warmup_steps)
  72. return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
  73. def _get_linear_schedule_with_warmup_lr_lambda(current_step: int, *, num_warmup_steps: int, num_training_steps: int):
  74. if current_step < num_warmup_steps:
  75. return float(current_step) / float(max(1, num_warmup_steps))
  76. return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
  77. def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
  78. """
  79. Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
  80. a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
  81. Args:
  82. optimizer ([`~torch.optim.Optimizer`]):
  83. The optimizer for which to schedule the learning rate.
  84. num_warmup_steps (`int`):
  85. The number of steps for the warmup phase.
  86. num_training_steps (`int`):
  87. The total number of training steps.
  88. last_epoch (`int`, *optional*, defaults to -1):
  89. The index of the last epoch when resuming training.
  90. Return:
  91. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  92. """
  93. lr_lambda = partial(
  94. _get_linear_schedule_with_warmup_lr_lambda,
  95. num_warmup_steps=num_warmup_steps,
  96. num_training_steps=num_training_steps,
  97. )
  98. return LambdaLR(optimizer, lr_lambda, last_epoch)
  99. def _get_cosine_schedule_with_warmup_lr_lambda(
  100. current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float
  101. ):
  102. if current_step < num_warmup_steps:
  103. return float(current_step) / float(max(1, num_warmup_steps))
  104. progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
  105. return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
  106. def get_cosine_schedule_with_warmup(
  107. optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
  108. ):
  109. """
  110. Create a schedule with a learning rate that decreases following the values of the cosine function between the
  111. initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
  112. initial lr set in the optimizer.
  113. Args:
  114. optimizer ([`~torch.optim.Optimizer`]):
  115. The optimizer for which to schedule the learning rate.
  116. num_warmup_steps (`int`):
  117. The number of steps for the warmup phase.
  118. num_training_steps (`int`):
  119. The total number of training steps.
  120. num_cycles (`float`, *optional*, defaults to 0.5):
  121. The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
  122. following a half-cosine).
  123. last_epoch (`int`, *optional*, defaults to -1):
  124. The index of the last epoch when resuming training.
  125. Return:
  126. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  127. """
  128. lr_lambda = partial(
  129. _get_cosine_schedule_with_warmup_lr_lambda,
  130. num_warmup_steps=num_warmup_steps,
  131. num_training_steps=num_training_steps,
  132. num_cycles=num_cycles,
  133. )
  134. return LambdaLR(optimizer, lr_lambda, last_epoch)
  135. def _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda(
  136. current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: int
  137. ):
  138. if current_step < num_warmup_steps:
  139. return float(current_step) / float(max(1, num_warmup_steps))
  140. progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
  141. if progress >= 1.0:
  142. return 0.0
  143. return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
  144. def get_cosine_with_hard_restarts_schedule_with_warmup(
  145. optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
  146. ):
  147. """
  148. Create a schedule with a learning rate that decreases following the values of the cosine function between the
  149. initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
  150. linearly between 0 and the initial lr set in the optimizer.
  151. Args:
  152. optimizer ([`~torch.optim.Optimizer`]):
  153. The optimizer for which to schedule the learning rate.
  154. num_warmup_steps (`int`):
  155. The number of steps for the warmup phase.
  156. num_training_steps (`int`):
  157. The total number of training steps.
  158. num_cycles (`int`, *optional*, defaults to 1):
  159. The number of hard restarts to use.
  160. last_epoch (`int`, *optional*, defaults to -1):
  161. The index of the last epoch when resuming training.
  162. Return:
  163. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  164. """
  165. lr_lambda = partial(
  166. _get_cosine_with_hard_restarts_schedule_with_warmup_lr_lambda,
  167. num_warmup_steps=num_warmup_steps,
  168. num_training_steps=num_training_steps,
  169. num_cycles=num_cycles,
  170. )
  171. return LambdaLR(optimizer, lr_lambda, last_epoch)
  172. def _get_polynomial_decay_schedule_with_warmup_lr_lambda(
  173. current_step: int,
  174. *,
  175. num_warmup_steps: int,
  176. num_training_steps: int,
  177. lr_end: float,
  178. power: float,
  179. lr_init: int,
  180. ):
  181. if current_step < num_warmup_steps:
  182. return float(current_step) / float(max(1, num_warmup_steps))
  183. elif current_step > num_training_steps:
  184. return lr_end / lr_init # as LambdaLR multiplies by lr_init
  185. else:
  186. lr_range = lr_init - lr_end
  187. decay_steps = num_training_steps - num_warmup_steps
  188. pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
  189. decay = lr_range * pct_remaining**power + lr_end
  190. return decay / lr_init # as LambdaLR multiplies by lr_init
  191. def get_polynomial_decay_schedule_with_warmup(
  192. optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
  193. ):
  194. """
  195. Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
  196. optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
  197. initial lr set in the optimizer.
  198. Args:
  199. optimizer ([`~torch.optim.Optimizer`]):
  200. The optimizer for which to schedule the learning rate.
  201. num_warmup_steps (`int`):
  202. The number of steps for the warmup phase.
  203. num_training_steps (`int`):
  204. The total number of training steps.
  205. lr_end (`float`, *optional*, defaults to 1e-7):
  206. The end LR.
  207. power (`float`, *optional*, defaults to 1.0):
  208. Power factor.
  209. last_epoch (`int`, *optional*, defaults to -1):
  210. The index of the last epoch when resuming training.
  211. Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
  212. implementation at
  213. https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
  214. Return:
  215. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  216. """
  217. lr_init = optimizer.defaults["lr"]
  218. if not (lr_init > lr_end):
  219. raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")
  220. lr_lambda = partial(
  221. _get_polynomial_decay_schedule_with_warmup_lr_lambda,
  222. num_warmup_steps=num_warmup_steps,
  223. num_training_steps=num_training_steps,
  224. lr_end=lr_end,
  225. power=power,
  226. lr_init=lr_init,
  227. )
  228. return LambdaLR(optimizer, lr_lambda, last_epoch)
  229. def _get_inverse_sqrt_schedule_lr_lambda(current_step: int, *, num_warmup_steps: int, timescale: Optional[int] = None):
  230. if current_step < num_warmup_steps:
  231. return float(current_step) / float(max(1, num_warmup_steps))
  232. shift = timescale - num_warmup_steps
  233. decay = 1.0 / math.sqrt((current_step + shift) / timescale)
  234. return decay
  235. def get_inverse_sqrt_schedule(
  236. optimizer: Optimizer, num_warmup_steps: int, timescale: Optional[int] = None, last_epoch: int = -1
  237. ):
  238. """
  239. Create a schedule with an inverse square-root learning rate, from the initial lr set in the optimizer, after a
  240. warmup period which increases lr linearly from 0 to the initial lr set in the optimizer.
  241. Args:
  242. optimizer ([`~torch.optim.Optimizer`]):
  243. The optimizer for which to schedule the learning rate.
  244. num_warmup_steps (`int`):
  245. The number of steps for the warmup phase.
  246. timescale (`int`, *optional*, defaults to `num_warmup_steps`):
  247. Time scale.
  248. last_epoch (`int`, *optional*, defaults to -1):
  249. The index of the last epoch when resuming training.
  250. Return:
  251. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  252. """
  253. # Note: this implementation is adapted from
  254. # https://github.com/google-research/big_vision/blob/f071ce68852d56099437004fd70057597a95f6ef/big_vision/utils.py#L930
  255. if timescale is None:
  256. timescale = num_warmup_steps or 10_000
  257. lr_lambda = partial(_get_inverse_sqrt_schedule_lr_lambda, num_warmup_steps=num_warmup_steps, timescale=timescale)
  258. return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
  259. def _get_cosine_schedule_with_warmup_lr_lambda(
  260. current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
  261. ):
  262. if current_step < num_warmup_steps:
  263. return float(current_step) / float(max(1, num_warmup_steps))
  264. progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
  265. factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
  266. factor = factor * (1 - min_lr_rate) + min_lr_rate
  267. return max(0, factor)
  268. def get_cosine_with_min_lr_schedule_with_warmup(
  269. optimizer: Optimizer,
  270. num_warmup_steps: int,
  271. num_training_steps: int,
  272. num_cycles: float = 0.5,
  273. last_epoch: int = -1,
  274. min_lr: Optional[float] = None,
  275. min_lr_rate: Optional[float] = None,
  276. ):
  277. """
  278. Create a schedule with a learning rate that decreases following the values of the cosine function between the
  279. initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
  280. initial lr set in the optimizer.
  281. Args:
  282. optimizer ([`~torch.optim.Optimizer`]):
  283. The optimizer for which to schedule the learning rate.
  284. num_warmup_steps (`int`):
  285. The number of steps for the warmup phase.
  286. num_training_steps (`int`):
  287. The total number of training steps.
  288. num_cycles (`float`, *optional*, defaults to 0.5):
  289. The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
  290. following a half-cosine).
  291. last_epoch (`int`, *optional*, defaults to -1):
  292. The index of the last epoch when resuming training.
  293. min_lr (`float`, *optional*):
  294. The minimum learning rate to reach after the cosine schedule.
  295. min_lr_rate (`float`, *optional*):
  296. The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
  297. Return:
  298. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  299. """
  300. if min_lr is not None and min_lr_rate is not None:
  301. raise ValueError("Only one of min_lr or min_lr_rate should be set")
  302. elif min_lr is not None:
  303. min_lr_rate = min_lr / optimizer.defaults["lr"]
  304. elif min_lr_rate is None:
  305. raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
  306. lr_lambda = partial(
  307. _get_cosine_schedule_with_warmup_lr_lambda,
  308. num_warmup_steps=num_warmup_steps,
  309. num_training_steps=num_training_steps,
  310. num_cycles=num_cycles,
  311. min_lr_rate=min_lr_rate,
  312. )
  313. return LambdaLR(optimizer, lr_lambda, last_epoch)
  314. def _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda(
  315. current_step: int,
  316. *,
  317. num_warmup_steps: int,
  318. num_training_steps: int,
  319. num_cycles: float,
  320. min_lr_rate: float = 0.0,
  321. warmup_lr_rate: Optional[float] = None,
  322. ):
  323. current_step = float(current_step)
  324. num_warmup_steps = float(num_warmup_steps)
  325. num_training_steps = float(num_training_steps)
  326. if current_step < num_warmup_steps:
  327. if warmup_lr_rate is None:
  328. return (current_step + 1.0) / max(1.0, num_warmup_steps)
  329. else:
  330. warmup_lr_rate = float(warmup_lr_rate)
  331. return warmup_lr_rate + (1.0 - warmup_lr_rate) * (current_step) / (max(1, num_warmup_steps - 1))
  332. progress = (current_step - num_warmup_steps + 1.0) / (max(1.0, num_training_steps - num_warmup_steps))
  333. factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
  334. factor = factor * (1 - min_lr_rate) + min_lr_rate
  335. return max(0, factor)
  336. def get_cosine_with_min_lr_schedule_with_warmup_lr_rate(
  337. optimizer: Optimizer,
  338. num_warmup_steps: int,
  339. num_training_steps: int,
  340. num_cycles: float = 0.5,
  341. last_epoch: int = -1,
  342. min_lr: Optional[float] = None,
  343. min_lr_rate: Optional[float] = None,
  344. warmup_lr_rate: Optional[float] = None,
  345. ):
  346. """
  347. Create a schedule with a learning rate that decreases following the values of the cosine function between the
  348. initial lr set in the optimizer to min_lr, after a warmup period during which it increases linearly between 0 and the
  349. initial lr set in the optimizer.
  350. Args:
  351. optimizer ([`~torch.optim.Optimizer`]):
  352. The optimizer for which to schedule the learning rate.
  353. num_warmup_steps (`int`):
  354. The number of steps for the warmup phase.
  355. num_training_steps (`int`):
  356. The total number of training steps.
  357. num_cycles (`float`, *optional*, defaults to 0.5):
  358. The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
  359. following a half-cosine).
  360. last_epoch (`int`, *optional*, defaults to -1):
  361. The index of the last epoch when resuming training.
  362. min_lr (`float`, *optional*):
  363. The minimum learning rate to reach after the cosine schedule.
  364. min_lr_rate (`float`, *optional*):
  365. The minimum learning rate as a ratio of the initial learning rate. If set, `min_lr` should not be set.
  366. warmup_lr_rate (`float`, *optional*):
  367. The minimum learning rate as a ratio of the start learning rate. If not set, `warmup_lr_rate` will be treated as float(1/num_warmup_steps).
  368. Return:
  369. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  370. """
  371. if min_lr is not None and min_lr_rate is not None:
  372. raise ValueError("Only one of min_lr or min_lr_rate should be set")
  373. elif min_lr is not None:
  374. min_lr_rate = min_lr / optimizer.defaults["lr"]
  375. elif min_lr_rate is None:
  376. raise ValueError("One of min_lr or min_lr_rate should be set through the `lr_scheduler_kwargs`")
  377. lr_lambda = partial(
  378. _get_cosine_with_min_lr_schedule_with_warmup_lr_rate_lambda,
  379. num_warmup_steps=num_warmup_steps,
  380. num_training_steps=num_training_steps,
  381. num_cycles=num_cycles,
  382. min_lr_rate=min_lr_rate,
  383. warmup_lr_rate=warmup_lr_rate,
  384. )
  385. return LambdaLR(optimizer, lr_lambda, last_epoch)
  386. def _get_wsd_scheduler_lambda(
  387. current_step: int,
  388. *,
  389. num_warmup_steps: int,
  390. num_stable_steps: int,
  391. num_decay_steps: int,
  392. warmup_type: str,
  393. decay_type: str,
  394. min_lr_ratio: float,
  395. num_cycles: float,
  396. ):
  397. if current_step < num_warmup_steps:
  398. progress = float(current_step) / float(max(1, num_warmup_steps))
  399. if warmup_type == "linear":
  400. factor = progress
  401. elif warmup_type == "cosine":
  402. factor = 0.5 * (1.0 - math.cos(math.pi * progress))
  403. elif warmup_type == "1-sqrt":
  404. factor = 1.0 - math.sqrt(1.0 - progress)
  405. factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
  406. return max(0.0, factor)
  407. if current_step < num_warmup_steps + num_stable_steps:
  408. return 1.0
  409. if current_step < num_warmup_steps + num_stable_steps + num_decay_steps:
  410. progress = float(current_step - num_warmup_steps - num_stable_steps) / float(max(1, num_decay_steps))
  411. if decay_type == "linear":
  412. factor = 1.0 - progress
  413. elif decay_type == "cosine":
  414. factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
  415. elif decay_type == "1-sqrt":
  416. factor = 1.0 - math.sqrt(progress)
  417. factor = factor * (1.0 - min_lr_ratio) + min_lr_ratio
  418. return max(0.0, factor)
  419. return min_lr_ratio
  420. def get_wsd_schedule(
  421. optimizer: Optimizer,
  422. num_warmup_steps: int,
  423. num_decay_steps: int,
  424. num_training_steps: Optional[int] = None,
  425. num_stable_steps: Optional[int] = None,
  426. warmup_type: str = "linear",
  427. decay_type: str = "cosine",
  428. min_lr_ratio: float = 0,
  429. num_cycles: float = 0.5,
  430. last_epoch: int = -1,
  431. ):
  432. """
  433. Create a schedule with a learning rate that has three stages:
  434. 1. warmup: increase from min_lr_ratio times the initial learning rate to the initial learning rate following a warmup_type.
  435. 2. stable: constant learning rate.
  436. 3. decay: decrease from the initial learning rate to min_lr_ratio times the initial learning rate following a decay_type.
  437. Args:
  438. optimizer ([`~torch.optim.Optimizer`]):
  439. The optimizer for which to schedule the learning rate.
  440. num_warmup_steps (`int`):
  441. The number of steps for the warmup phase.
  442. num_decay_steps (`int`):
  443. The number of steps for the decay phase.
  444. num_training_steps (`int`, *optional*):
  445. The total number of training steps. This is the sum of the warmup, stable and decay steps. If `num_stable_steps` is not provided, the stable phase will be `num_training_steps - num_warmup_steps - num_decay_steps`.
  446. num_stable_steps (`int`, *optional*):
  447. The number of steps for the stable phase. Please ensure that `num_warmup_steps + num_stable_steps + num_decay_steps` equals `num_training_steps`, otherwise the other steps will default to the minimum learning rate.
  448. warmup_type (`str`, *optional*, defaults to "linear"):
  449. The type of warmup to use. Can be 'linear', 'cosine' or '1-sqrt'.
  450. decay_type (`str`, *optional*, defaults to "cosine"):
  451. The type of decay to use. Can be 'linear', 'cosine' or '1-sqrt'.
  452. min_lr_ratio (`float`, *optional*, defaults to 0):
  453. The minimum learning rate as a ratio of the initial learning rate.
  454. num_cycles (`float`, *optional*, defaults to 0.5):
  455. The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
  456. following a half-cosine).
  457. last_epoch (`int`, *optional*, defaults to -1):
  458. The index of the last epoch when resuming training.
  459. Return:
  460. `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
  461. """
  462. if num_training_steps is None and num_stable_steps is None:
  463. raise ValueError("Either num_training_steps or num_stable_steps must be specified.")
  464. if num_training_steps is not None and num_stable_steps is not None:
  465. warnings.warn("Both num_training_steps and num_stable_steps are specified. num_stable_steps will be used.")
  466. if warmup_type not in ["linear", "cosine", "1-sqrt"]:
  467. raise ValueError(f"Unknown warmup type: {warmup_type}, expected 'linear', 'cosine' or '1-sqrt'")
  468. if decay_type not in ["linear", "cosine", "1-sqrt"]:
  469. raise ValueError(f"Unknown decay type: {decay_type}, expected 'linear', 'cosine' or '1-sqrt'")
  470. if num_stable_steps is None:
  471. num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps
  472. lr_lambda = partial(
  473. _get_wsd_scheduler_lambda,
  474. num_warmup_steps=num_warmup_steps,
  475. num_stable_steps=num_stable_steps,
  476. num_decay_steps=num_decay_steps,
  477. warmup_type=warmup_type,
  478. decay_type=decay_type,
  479. min_lr_ratio=min_lr_ratio,
  480. num_cycles=num_cycles,
  481. )
  482. return LambdaLR(optimizer, lr_lambda, last_epoch)
# Maps each `SchedulerType` member to the factory function in this module that builds the matching scheduler.
# `get_scheduler` looks the factory up here by name.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.INVERSE_SQRT: get_inverse_sqrt_schedule,
    SchedulerType.REDUCE_ON_PLATEAU: get_reduce_on_plateau_schedule,
    SchedulerType.COSINE_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup,
    SchedulerType.COSINE_WARMUP_WITH_MIN_LR: get_cosine_with_min_lr_schedule_with_warmup_lr_rate,
    SchedulerType.WARMUP_STABLE_DECAY: get_wsd_schedule,
}
  496. def get_scheduler(
  497. name: Union[str, SchedulerType],
  498. optimizer: Optimizer,
  499. num_warmup_steps: Optional[int] = None,
  500. num_training_steps: Optional[int] = None,
  501. scheduler_specific_kwargs: Optional[dict] = None,
  502. ):
  503. """
  504. Unified API to get any scheduler from its name.
  505. Args:
  506. name (`str` or `SchedulerType`):
  507. The name of the scheduler to use.
  508. optimizer (`torch.optim.Optimizer`):
  509. The optimizer that will be used during training.
  510. num_warmup_steps (`int`, *optional*):
  511. The number of warmup steps to do. This is not required by all schedulers (hence the argument being
  512. optional), the function will raise an error if it's unset and the scheduler type requires it.
  513. num_training_steps (`int``, *optional*):
  514. The number of training steps to do. This is not required by all schedulers (hence the argument being
  515. optional), the function will raise an error if it's unset and the scheduler type requires it.
  516. scheduler_specific_kwargs (`dict`, *optional*):
  517. Extra parameters for schedulers such as cosine with restarts. Mismatched scheduler types and scheduler
  518. parameters will cause the scheduler function to raise a TypeError.
  519. """
  520. name = SchedulerType(name)
  521. schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
  522. # If a `LayerWiseDummyOptimizer` is passed we extract the optimizer dict and
  523. # recursively call `get_scheduler` to get the proper schedulers on each parameter
  524. if optimizer is not None and isinstance(optimizer, LayerWiseDummyOptimizer):
  525. optimizer_dict = optimizer.optimizer_dict
  526. scheduler_dict = {}
  527. for param in optimizer_dict:
  528. scheduler_dict[param] = get_scheduler(
  529. name,
  530. optimizer=optimizer_dict[param],
  531. num_warmup_steps=num_warmup_steps,
  532. num_training_steps=num_training_steps,
  533. scheduler_specific_kwargs=scheduler_specific_kwargs,
  534. )
  535. def scheduler_hook(param):
  536. # Since the optimizer hook has been already attached we only need to
  537. # attach the scheduler hook, the gradients have been zeroed here
  538. scheduler_dict[param].step()
  539. for param in optimizer_dict:
  540. if param.requires_grad:
  541. param.register_post_accumulate_grad_hook(scheduler_hook)
  542. return LayerWiseDummyScheduler(optimizer_dict=optimizer_dict, lr=optimizer.defaults["lr"])
  543. if name == SchedulerType.CONSTANT:
  544. return schedule_func(optimizer)
  545. if scheduler_specific_kwargs is None:
  546. scheduler_specific_kwargs = {}
  547. if name == SchedulerType.REDUCE_ON_PLATEAU:
  548. return schedule_func(optimizer, **scheduler_specific_kwargs)
  549. # All other schedulers require `num_warmup_steps`
  550. if num_warmup_steps is None:
  551. raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
  552. if name == SchedulerType.CONSTANT_WITH_WARMUP:
  553. return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
  554. if name == SchedulerType.INVERSE_SQRT:
  555. return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
  556. # wsd scheduler requires either num_training_steps or num_stable_steps
  557. if name == SchedulerType.WARMUP_STABLE_DECAY:
  558. return schedule_func(
  559. optimizer,
  560. num_warmup_steps=num_warmup_steps,
  561. num_training_steps=num_training_steps,
  562. **scheduler_specific_kwargs,
  563. )
  564. # All other schedulers require `num_training_steps`
  565. if num_training_steps is None:
  566. raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
  567. return schedule_func(
  568. optimizer,
  569. num_warmup_steps=num_warmup_steps,
  570. num_training_steps=num_training_steps,
  571. **scheduler_specific_kwargs,
  572. )
  573. class Adafactor(Optimizer):
  574. """
  575. AdaFactor pytorch implementation can be used as a drop in replacement for Adam original fairseq code:
  576. https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
  577. Paper: *Adafactor: Adaptive Learning Rates with Sublinear Memory Cost* https://huggingface.co/papers/1804.04235 Note that
  578. this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and
  579. `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and
  580. `relative_step=False`.
  581. Arguments:
  582. params (`Iterable[nn.parameter.Parameter]`):
  583. Iterable of parameters to optimize or dictionaries defining parameter groups.
  584. lr (`float`, *optional*):
  585. The external learning rate.
  586. eps (`tuple[float, float]`, *optional*, defaults to `(1e-30, 0.001)`):
  587. Regularization constants for square gradient and parameter scale respectively
  588. clip_threshold (`float`, *optional*, defaults to 1.0):
  589. Threshold of root mean square of final gradient update
  590. decay_rate (`float`, *optional*, defaults to -0.8):
  591. Coefficient used to compute running averages of square
  592. beta1 (`float`, *optional*):
  593. Coefficient used for computing running averages of gradient
  594. weight_decay (`float`, *optional*, defaults to 0.0):
  595. Weight decay (L2 penalty)
  596. scale_parameter (`bool`, *optional*, defaults to `True`):
  597. If True, learning rate is scaled by root mean square
  598. relative_step (`bool`, *optional*, defaults to `True`):
  599. If True, time-dependent learning rate is computed instead of external learning rate
  600. warmup_init (`bool`, *optional*, defaults to `False`):
  601. Time-dependent learning rate computation depends on whether warm-up initialization is being used
  602. This implementation handles low-precision (FP16, bfloat) values, but we have not thoroughly tested.
  603. Recommended T5 finetuning settings (https://discuss.huggingface.co/t/t5-finetuning-tips/684/3):
  604. - Training without LR warmup or clip_threshold is not recommended.
  605. - use scheduled LR warm-up to fixed LR
  606. - use clip_threshold=1.0 (https://huggingface.co/papers/1804.04235)
  607. - Disable relative updates
  608. - Use scale_parameter=False
  609. - Additional optimizer operations like gradient clipping should not be used alongside Adafactor
  610. Example:
  611. ```python
  612. Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=1e-3)
  613. ```
  614. Others reported the following combination to work well:
  615. ```python
  616. Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
  617. ```
  618. When using `lr=None` with [`Trainer`] you will most likely need to use [`~optimization.AdafactorSchedule`]
  619. scheduler as following:
  620. ```python
  621. from transformers.optimization import Adafactor, AdafactorSchedule
  622. optimizer = Adafactor(model.parameters(), scale_parameter=True, relative_step=True, warmup_init=True, lr=None)
  623. lr_scheduler = AdafactorSchedule(optimizer)
  624. trainer = Trainer(..., optimizers=(optimizer, lr_scheduler))
  625. ```
  626. Usage:
  627. ```python
  628. # replace AdamW with Adafactor
  629. optimizer = Adafactor(
  630. model.parameters(),
  631. lr=1e-3,
  632. eps=(1e-30, 1e-3),
  633. clip_threshold=1.0,
  634. decay_rate=-0.8,
  635. beta1=None,
  636. weight_decay=0.0,
  637. relative_step=False,
  638. scale_parameter=False,
  639. warmup_init=False,
  640. )
  641. ```"""
  642. def __init__(
  643. self,
  644. params,
  645. lr=None,
  646. eps=(1e-30, 1e-3),
  647. clip_threshold=1.0,
  648. decay_rate=-0.8,
  649. beta1=None,
  650. weight_decay=0.0,
  651. scale_parameter=True,
  652. relative_step=True,
  653. warmup_init=False,
  654. ):
  655. if lr is not None and relative_step:
  656. raise ValueError("Cannot combine manual `lr` and `relative_step=True` options")
  657. if warmup_init and not relative_step:
  658. raise ValueError("`warmup_init=True` requires `relative_step=True`")
  659. defaults = {
  660. "lr": lr,
  661. "eps": eps,
  662. "clip_threshold": clip_threshold,
  663. "decay_rate": decay_rate,
  664. "beta1": beta1,
  665. "weight_decay": weight_decay,
  666. "scale_parameter": scale_parameter,
  667. "relative_step": relative_step,
  668. "warmup_init": warmup_init,
  669. }
  670. super().__init__(params, defaults)
  671. @staticmethod
  672. def _get_lr(param_group, param_state):
  673. rel_step_sz = param_group["lr"]
  674. if param_group["relative_step"]:
  675. min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
  676. rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
  677. param_scale = 1.0
  678. if param_group["scale_parameter"]:
  679. param_scale = max(param_group["eps"][1], param_state["RMS"])
  680. return param_scale * rel_step_sz
  681. @staticmethod
  682. def _get_options(param_group, param_shape):
  683. factored = len(param_shape) >= 2
  684. use_first_moment = param_group["beta1"] is not None
  685. return factored, use_first_moment
  686. @staticmethod
  687. def _rms(tensor):
  688. return tensor.norm(2) / (tensor.numel() ** 0.5)
  689. @staticmethod
  690. def _approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col):
  691. # copy from fairseq's adafactor implementation:
  692. # https://github.com/huggingface/transformers/blob/8395f14de6068012787d83989c3627c3df6a252b/src/transformers/optimization.py#L505
  693. r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
  694. c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
  695. return torch.mul(r_factor, c_factor)
  696. @torch.no_grad()
  697. def step(self, closure=None):
  698. """
  699. Performs a single optimization step
  700. Arguments:
  701. closure (callable, optional): A closure that reevaluates the model
  702. and returns the loss.
  703. """
  704. loss = None
  705. if closure is not None:
  706. loss = closure()
  707. for group in self.param_groups:
  708. for p in group["params"]:
  709. if p.grad is None:
  710. continue
  711. grad = p.grad
  712. if grad.dtype in {torch.float16, torch.bfloat16}:
  713. grad = grad.float()
  714. if grad.is_sparse:
  715. raise RuntimeError("Adafactor does not support sparse gradients.")
  716. state = self.state[p]
  717. grad_shape = grad.shape
  718. factored, use_first_moment = self._get_options(group, grad_shape)
  719. # State Initialization
  720. if len(state) == 0:
  721. state["step"] = 0
  722. if use_first_moment:
  723. # Exponential moving average of gradient values
  724. state["exp_avg"] = torch.zeros_like(grad)
  725. if factored:
  726. state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad)
  727. state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
  728. else:
  729. state["exp_avg_sq"] = torch.zeros_like(grad)
  730. state["RMS"] = 0
  731. else:
  732. if use_first_moment:
  733. state["exp_avg"] = state["exp_avg"].to(grad)
  734. if factored:
  735. state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad)
  736. state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad)
  737. else:
  738. state["exp_avg_sq"] = state["exp_avg_sq"].to(grad)
  739. p_data_fp32 = p
  740. if p.dtype in {torch.float16, torch.bfloat16}:
  741. p_data_fp32 = p_data_fp32.float()
  742. state["step"] += 1
  743. state["RMS"] = self._rms(p_data_fp32)
  744. lr = self._get_lr(group, state)
  745. beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
  746. update = (grad**2) + group["eps"][0]
  747. if factored:
  748. exp_avg_sq_row = state["exp_avg_sq_row"]
  749. exp_avg_sq_col = state["exp_avg_sq_col"]
  750. exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t))
  751. exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t))
  752. # Approximation of exponential moving average of square of gradient
  753. update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
  754. update.mul_(grad)
  755. else:
  756. exp_avg_sq = state["exp_avg_sq"]
  757. exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t))
  758. update = exp_avg_sq.rsqrt().mul_(grad)
  759. update.div_((self._rms(update) / group["clip_threshold"]).clamp_(min=1.0))
  760. update.mul_(lr)
  761. if use_first_moment:
  762. exp_avg = state["exp_avg"]
  763. exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"]))
  764. update = exp_avg
  765. if group["weight_decay"] != 0:
  766. p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr))
  767. p_data_fp32.add_(-update)
  768. if p.dtype in {torch.float16, torch.bfloat16}:
  769. p.copy_(p_data_fp32)
  770. return loss
  771. class AdafactorSchedule(LambdaLR):
  772. """
  773. Since [`~optimization.Adafactor`] performs its own scheduling, if the training loop relies on a scheduler (e.g.,
  774. for logging), this class creates a proxy object that retrieves the current lr values from the optimizer.
  775. It returns `initial_lr` during startup and the actual `lr` during stepping.
  776. """
  777. def __init__(self, optimizer, initial_lr=0.0):
  778. def lr_lambda(_):
  779. return initial_lr
  780. for group in optimizer.param_groups:
  781. group["initial_lr"] = initial_lr
  782. super().__init__(optimizer, lr_lambda)
  783. for group in optimizer.param_groups:
  784. del group["initial_lr"]
  785. def get_lr(self):
  786. opt = self.optimizer
  787. lrs = [
  788. opt._get_lr(group, opt.state[group["params"][0]])
  789. for group in opt.param_groups
  790. if group["params"][0].grad is not None
  791. ]
  792. if len(lrs) == 0:
  793. lrs = self.base_lrs # if called before stepping
  794. return lrs
  795. def get_adafactor_schedule(optimizer, initial_lr=0.0):
  796. """
  797. Get a proxy schedule for [`~optimization.Adafactor`]
  798. Args:
  799. optimizer ([`~torch.optim.Optimizer`]):
  800. The optimizer for which to schedule the learning rate.
  801. initial_lr (`float`, *optional*, defaults to 0.0):
  802. Initial lr
  803. Return:
  804. [`~optimization.Adafactor`] proxy schedule object.
  805. """
  806. return AdafactorSchedule(optimizer, initial_lr)