| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- # Copyright 2022 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation
- import warnings
- from .state import AcceleratorState, GradientState
# Install an "ignore" filter for every `UserWarning` whose originating module matches
# `torch.optim.lr_scheduler`. This runs as a side effect of importing this module.
- warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler")
class AcceleratedScheduler:
    """
    Wrapper around a learning rate scheduler that advances it only when the optimizer(s) really took a training
    step. This prevents the schedule from running ahead when a step was dropped because the gradients overflowed
    (mixed precision training).

    When performing gradient accumulation, scheduler lengths should not be changed accordingly: Accelerate will
    always step the scheduler to account for it.

    Args:
        scheduler (`torch.optim.lr_scheduler._LRScheduler`):
            The scheduler to wrap.
        optimizers (one or a list of `torch.optim.Optimizer`):
            The optimizers used.
        step_with_optimizer (`bool`, *optional*, defaults to `True`):
            Whether or not the scheduler should be stepped at each optimizer step.
        split_batches (`bool`, *optional*, defaults to `False`):
            Whether or not the dataloaders split one batch across the different processes (so batch size is the same
            regardless of the number of processes) or create batches on each process (so batch size is the original
            batch size multiplied by the number of processes).
    """

    def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False):
        # A lone optimizer is wrapped in a list; lists and tuples are stored as given.
        if isinstance(optimizers, (list, tuple)):
            self.optimizers = optimizers
        else:
            self.optimizers = [optimizers]
        self.scheduler = scheduler
        self.step_with_optimizer = step_with_optimizer
        self.split_batches = split_batches
        self.gradient_state = GradientState()

    def step(self, *args, **kwargs):
        """
        Advance the wrapped scheduler, forwarding `*args` / `**kwargs` to its `step` method.

        If `step_with_optimizer` is `False` the scheduler is stepped once, unconditionally. Otherwise nothing is
        stepped while gradients are not being synchronized (only `_step_count` is bumped when
        `gradient_state.adjust_scheduler` is set) or when any optimizer reports `step_was_skipped`. Past those
        guards the scheduler is stepped once when `split_batches` is set, and `num_processes` times otherwise.
        """
        if self.step_with_optimizer:
            # The scheduler is tied to the optimizer: first make sure the optimizer was actually stepped.
            if not self.gradient_state.sync_gradients:
                if self.gradient_state.adjust_scheduler:
                    self.scheduler._step_count += 1
                return

            if any(optimizer.step_was_skipped for optimizer in self.optimizers):
                return

            if not self.split_batches:
                # The training dataloader batch size was multiplied by `num_processes`, so one training step
                # is worth `num_processes` scheduler steps.
                for _ in range(AcceleratorState().num_processes):
                    wrapped = self.scheduler
                    # Special case when using OneCycle and `drop_last` was not used: never go past
                    # `total_steps`. Schedulers without that attribute are always stepped.
                    if not hasattr(wrapped, "total_steps") or wrapped._step_count <= wrapped.total_steps:
                        wrapped.step(*args, **kwargs)
                return

        # Reached when there is no link between scheduler and optimizer, or when batches are split (the
        # training dataloader batch size is unchanged): exactly one scheduler step.
        self.scheduler.step(*args, **kwargs)

    # Passthroughs: each method below simply delegates to the wrapped scheduler.
    def get_last_lr(self):
        """Return the wrapped scheduler's `get_last_lr()`."""
        return self.scheduler.get_last_lr()

    def state_dict(self):
        """Return the wrapped scheduler's `state_dict()`."""
        return self.scheduler.state_dict()

    def load_state_dict(self, state_dict):
        """Forward `state_dict` to the wrapped scheduler's `load_state_dict`."""
        self.scheduler.load_state_dict(state_dict)

    def get_lr(self):
        """Return the wrapped scheduler's `get_lr()`."""
        return self.scheduler.get_lr()

    def print_lr(self, *args, **kwargs):
        """Call the wrapped scheduler's `print_lr` with the given arguments and return its result."""
        return self.scheduler.print_lr(*args, **kwargs)
|