AnnealingLR.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Learning rate scheduler with linear warmup and optional linear/cosine decay."""
  15. import math
  16. import torch
  17. from torch.optim.lr_scheduler import _LRScheduler
  18. class AnnealingLR(_LRScheduler):
  19. """Anneals the learning rate from start to zero along a cosine curve."""
  20. DECAY_STYLES = ['linear', 'cosine', 'exponential', 'constant', 'None']
  21. def __init__(self,
  22. optimizer,
  23. start_lr,
  24. warmup_iter,
  25. num_iters,
  26. decay_style=None,
  27. last_iter=-1):
  28. self.optimizer = optimizer
  29. self.start_lr = start_lr
  30. self.warmup_iter = warmup_iter
  31. self._step_count = last_iter + 1
  32. self.end_iter = num_iters
  33. self.decay_style = decay_style.lower() if isinstance(decay_style,
  34. str) else None
  35. self.step(self._step_count)
  36. if torch.distributed.get_rank() == 0:
  37. print('learning rate decaying', decay_style)
  38. def get_lr(self):
  39. # https://openreview.net/pdf?id=BJYwwY9ll pg. 4
  40. if self.warmup_iter > 0 and self._step_count <= self.warmup_iter:
  41. return float(self.start_lr) * self._step_count / self.warmup_iter
  42. else:
  43. if self.decay_style == self.DECAY_STYLES[0]:
  44. return self.start_lr * ((
  45. self.end_iter - # noqa W504
  46. (self._step_count - self.warmup_iter)) / self.end_iter)
  47. elif self.decay_style == self.DECAY_STYLES[1]:
  48. return self.start_lr / 2.0 * (
  49. math.cos(math.pi * (self._step_count - self.warmup_iter)
  50. / self.end_iter) + 1)
  51. elif self.decay_style == self.DECAY_STYLES[2]:
  52. # TODO: implement exponential decay
  53. return self.start_lr
  54. else:
  55. return self.start_lr
  56. def step(self, step_num=None):
  57. if step_num is None:
  58. step_num = self._step_count + 1
  59. self._step_count = step_num
  60. new_lr = self.get_lr()
  61. for group in self.optimizer.param_groups:
  62. group['lr'] = new_lr
  63. def state_dict(self):
  64. sd = {
  65. 'start_lr': self.start_lr,
  66. 'warmup_iter': self.warmup_iter,
  67. '_step_count': self._step_count,
  68. 'decay_style': self.decay_style,
  69. 'end_iter': self.end_iter
  70. }
  71. return sd
  72. def load_state_dict(self, sd):
  73. self.start_lr = sd['start_lr']
  74. self.warmup_iter = sd['warmup_iter']
  75. self._step_count = sd['_step_count']
  76. self.end_iter = sd['end_iter']
  77. self.decay_style = sd['decay_style']
  78. self.step(self._step_count)