lr_scheduler.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from paddle.optimizer.lr import LRScheduler
  16. class CyclicalCosineDecay(LRScheduler):
  17. def __init__(
  18. self, learning_rate, T_max, cycle=1, last_epoch=-1, eta_min=0.0, verbose=False
  19. ):
  20. """
  21. Cyclical cosine learning rate decay
  22. A learning rate which can be referred in https://arxiv.org/pdf/2012.12645.pdf
  23. Args:
  24. learning rate(float): learning rate
  25. T_max(int): maximum epoch num
  26. cycle(int): period of the cosine decay
  27. last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
  28. eta_min(float): minimum learning rate during training
  29. verbose(bool): whether to print learning rate for each epoch
  30. """
  31. super(CyclicalCosineDecay, self).__init__(learning_rate, last_epoch, verbose)
  32. self.cycle = cycle
  33. self.eta_min = eta_min
  34. def get_lr(self):
  35. if self.last_epoch == 0:
  36. return self.base_lr
  37. reletive_epoch = self.last_epoch % self.cycle
  38. lr = self.eta_min + 0.5 * (self.base_lr - self.eta_min) * (
  39. 1 + math.cos(math.pi * reletive_epoch / self.cycle)
  40. )
  41. return lr
  42. class OneCycleDecay(LRScheduler):
  43. """
  44. One Cycle learning rate decay
  45. A learning rate which can be referred in https://arxiv.org/abs/1708.07120
  46. Code referred in https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#OneCycleLR
  47. """
  48. def __init__(
  49. self,
  50. max_lr,
  51. epochs=None,
  52. steps_per_epoch=None,
  53. pct_start=0.3,
  54. anneal_strategy="cos",
  55. div_factor=25.0,
  56. final_div_factor=1e4,
  57. three_phase=False,
  58. last_epoch=-1,
  59. verbose=False,
  60. ):
  61. # Validate total_steps
  62. if epochs <= 0 or not isinstance(epochs, int):
  63. raise ValueError(
  64. "Expected positive integer epochs, but got {}".format(epochs)
  65. )
  66. if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
  67. raise ValueError(
  68. "Expected positive integer steps_per_epoch, but got {}".format(
  69. steps_per_epoch
  70. )
  71. )
  72. self.total_steps = epochs * steps_per_epoch
  73. self.max_lr = max_lr
  74. self.initial_lr = self.max_lr / div_factor
  75. self.min_lr = self.initial_lr / final_div_factor
  76. if three_phase:
  77. self._schedule_phases = [
  78. {
  79. "end_step": float(pct_start * self.total_steps) - 1,
  80. "start_lr": self.initial_lr,
  81. "end_lr": self.max_lr,
  82. },
  83. {
  84. "end_step": float(2 * pct_start * self.total_steps) - 2,
  85. "start_lr": self.max_lr,
  86. "end_lr": self.initial_lr,
  87. },
  88. {
  89. "end_step": self.total_steps - 1,
  90. "start_lr": self.initial_lr,
  91. "end_lr": self.min_lr,
  92. },
  93. ]
  94. else:
  95. self._schedule_phases = [
  96. {
  97. "end_step": float(pct_start * self.total_steps) - 1,
  98. "start_lr": self.initial_lr,
  99. "end_lr": self.max_lr,
  100. },
  101. {
  102. "end_step": self.total_steps - 1,
  103. "start_lr": self.max_lr,
  104. "end_lr": self.min_lr,
  105. },
  106. ]
  107. # Validate pct_start
  108. if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
  109. raise ValueError(
  110. "Expected float between 0 and 1 pct_start, but got {}".format(pct_start)
  111. )
  112. # Validate anneal_strategy
  113. if anneal_strategy not in ["cos", "linear"]:
  114. raise ValueError(
  115. "anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(
  116. anneal_strategy
  117. )
  118. )
  119. elif anneal_strategy == "cos":
  120. self.anneal_func = self._annealing_cos
  121. elif anneal_strategy == "linear":
  122. self.anneal_func = self._annealing_linear
  123. super(OneCycleDecay, self).__init__(max_lr, last_epoch, verbose)
  124. def _annealing_cos(self, start, end, pct):
  125. "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
  126. cos_out = math.cos(math.pi * pct) + 1
  127. return end + (start - end) / 2.0 * cos_out
  128. def _annealing_linear(self, start, end, pct):
  129. "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
  130. return (end - start) * pct + start
  131. def get_lr(self):
  132. computed_lr = 0.0
  133. step_num = self.last_epoch
  134. if step_num > self.total_steps:
  135. raise ValueError(
  136. "Tried to step {} times. The specified number of total steps is {}".format(
  137. step_num + 1, self.total_steps
  138. )
  139. )
  140. start_step = 0
  141. for i, phase in enumerate(self._schedule_phases):
  142. end_step = phase["end_step"]
  143. if step_num <= end_step or i == len(self._schedule_phases) - 1:
  144. pct = (step_num - start_step) / (end_step - start_step)
  145. computed_lr = self.anneal_func(phase["start_lr"], phase["end_lr"], pct)
  146. break
  147. start_step = phase["end_step"]
  148. return computed_lr
  149. class TwoStepCosineDecay(LRScheduler):
  150. def __init__(
  151. self, learning_rate, T_max1, T_max2, eta_min=0, last_epoch=-1, verbose=False
  152. ):
  153. if not isinstance(T_max1, int):
  154. raise TypeError(
  155. "The type of 'T_max1' in 'CosineAnnealingDecay' must be 'int', but received %s."
  156. % type(T_max1)
  157. )
  158. if not isinstance(T_max2, int):
  159. raise TypeError(
  160. "The type of 'T_max2' in 'CosineAnnealingDecay' must be 'int', but received %s."
  161. % type(T_max2)
  162. )
  163. if not isinstance(eta_min, (float, int)):
  164. raise TypeError(
  165. "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
  166. % type(eta_min)
  167. )
  168. assert T_max1 > 0 and isinstance(
  169. T_max1, int
  170. ), " 'T_max1' must be a positive integer."
  171. assert T_max2 > 0 and isinstance(
  172. T_max2, int
  173. ), " 'T_max1' must be a positive integer."
  174. self.T_max1 = T_max1
  175. self.T_max2 = T_max2
  176. self.eta_min = float(eta_min)
  177. super(TwoStepCosineDecay, self).__init__(learning_rate, last_epoch, verbose)
  178. def get_lr(self):
  179. if self.last_epoch <= self.T_max1:
  180. if self.last_epoch == 0:
  181. return self.base_lr
  182. elif (self.last_epoch - 1 - self.T_max1) % (2 * self.T_max1) == 0:
  183. return (
  184. self.last_lr
  185. + (self.base_lr - self.eta_min)
  186. * (1 - math.cos(math.pi / self.T_max1))
  187. / 2
  188. )
  189. return (1 + math.cos(math.pi * self.last_epoch / self.T_max1)) / (
  190. 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max1)
  191. ) * (self.last_lr - self.eta_min) + self.eta_min
  192. else:
  193. if (self.last_epoch - 1 - self.T_max2) % (2 * self.T_max2) == 0:
  194. return (
  195. self.last_lr
  196. + (self.base_lr - self.eta_min)
  197. * (1 - math.cos(math.pi / self.T_max2))
  198. / 2
  199. )
  200. return (1 + math.cos(math.pi * self.last_epoch / self.T_max2)) / (
  201. 1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max2)
  202. ) * (self.last_lr - self.eta_min) + self.eta_min
  203. def _get_closed_form_lr(self):
  204. if self.last_epoch <= self.T_max1:
  205. return (
  206. self.eta_min
  207. + (self.base_lr - self.eta_min)
  208. * (1 + math.cos(math.pi * self.last_epoch / self.T_max1))
  209. / 2
  210. )
  211. else:
  212. return (
  213. self.eta_min
  214. + (self.base_lr - self.eta_min)
  215. * (1 + math.cos(math.pi * self.last_epoch / self.T_max2))
  216. / 2
  217. )