# scheduler.py (4.1 KB)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. # We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
  15. import warnings
  16. from .state import AcceleratorState, GradientState
  17. warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")
  18. class AcceleratedScheduler:
  19. """
  20. A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful
  21. to avoid making a scheduler step too fast when gradients went overflow and there was no training step (in mixed
  22. precision training)
  23. When performing gradient accumulation scheduler lengths should not be changed accordingly, Accelerate will always
  24. step the scheduler to account for it.
  25. Args:
  26. scheduler (`torch.optim.lr_scheduler._LRScheduler`):
  27. The scheduler to wrap.
  28. optimizers (one or a list of `torch.optim.Optimizer`):
  29. The optimizers used.
  30. step_with_optimizer (`bool`, *optional*, defaults to `True`):
  31. Whether or not the scheduler should be stepped at each optimizer step.
  32. split_batches (`bool`, *optional*, defaults to `False`):
  33. Whether or not the dataloaders split one batch across the different processes (so batch size is the same
  34. regardless of the number of processes) or create batches on each process (so batch size is the original
  35. batch size multiplied by the number of processes).
  36. """
  37. def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
  38. self.scheduler = scheduler
  39. self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers]
  40. self.split_batches = split_batches
  41. self.step_with_optimizer = step_with_optimizer
  42. self.gradient_state = GradientState()
  43. def step(self, *args, **kwargs):
  44. if not self.step_with_optimizer:
  45. # No link between scheduler and optimizer -> just step
  46. self.scheduler.step(*args, **kwargs)
  47. return
  48. # Otherwise, first make sure the optimizer was stepped.
  49. if not self.gradient_state.sync_gradients:
  50. if self.gradient_state.adjust_scheduler:
  51. self.scheduler._step_count += 1
  52. return
  53. for opt in self.optimizers:
  54. if opt.step_was_skipped:
  55. return
  56. if self.split_batches:
  57. # Split batches -> the training dataloader batch size is not changed so one step per training step
  58. self.scheduler.step(*args, **kwargs)
  59. else:
  60. # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do
  61. # num_processes steps per training step
  62. num_processes = AcceleratorState().num_processes
  63. for _ in range(num_processes):
  64. # Special case when using OneCycle and `drop_last` was not used
  65. if hasattr(self.scheduler, "total_steps"):
  66. if self.scheduler._step_count <= self.scheduler.total_steps:
  67. self.scheduler.step(*args, **kwargs)
  68. else:
  69. self.scheduler.step(*args, **kwargs)
  70. # Passthroughs
  71. def get_last_lr(self):
  72. return self.scheduler.get_last_lr()
  73. def state_dict(self):
  74. return self.scheduler.state_dict()
  75. def load_state_dict(self, state_dict):
  76. self.scheduler.load_state_dict(state_dict)
  77. def get_lr(self):
  78. return self.scheduler.get_lr()
  79. def print_lr(self, *args, **kwargs):
  80. return self.scheduler.print_lr(*args, **kwargs)