optimizer.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # Copyright 2021 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import torch
  16. from .state import AcceleratorState, GradientState
  17. from .utils import DistributedType, honor_type, is_lomo_available, is_torch_xla_available
  18. if is_torch_xla_available():
  19. import torch_xla.core.xla_model as xm
  20. import torch_xla.runtime as xr
  21. def move_to_device(state, device):
  22. if isinstance(state, (list, tuple)):
  23. return honor_type(state, (move_to_device(t, device) for t in state))
  24. elif isinstance(state, dict):
  25. return type(state)({k: move_to_device(v, device) for k, v in state.items()})
  26. elif isinstance(state, torch.Tensor):
  27. return state.to(device)
  28. return state
  29. class AcceleratedOptimizer(torch.optim.Optimizer):
  30. """
  31. Internal wrapper around a torch optimizer.
  32. Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient
  33. accumulation.
  34. Args:
  35. optimizer (`torch.optim.optimizer.Optimizer`):
  36. The optimizer to wrap.
  37. device_placement (`bool`, *optional*, defaults to `True`):
  38. Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of
  39. `optimizer` on the right device.
  40. scaler (`torch.amp.GradScaler` or `torch.cuda.amp.GradScaler`, *optional*):
  41. The scaler to use in the step function if training with mixed precision.
  42. """
  43. def __init__(self, optimizer, device_placement=True, scaler=None):
  44. self.optimizer = optimizer
  45. self.scaler = scaler
  46. self.accelerator_state = AcceleratorState()
  47. self.gradient_state = GradientState()
  48. self.device_placement = device_placement
  49. self._is_overflow = False
  50. if self.scaler is not None:
  51. self._accelerate_step_called = False
  52. self._optimizer_original_step_method = self.optimizer.step
  53. self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
  54. # Handle device placement
  55. if device_placement:
  56. state_dict = self.optimizer.state_dict()
  57. if self.accelerator_state.distributed_type == DistributedType.XLA:
  58. xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
  59. else:
  60. state_dict = move_to_device(state_dict, self.accelerator_state.device)
  61. self.optimizer.load_state_dict(state_dict)
  62. @property
  63. def state(self):
  64. return self.optimizer.state
  65. @state.setter
  66. def state(self, state):
  67. self.optimizer.state = state
  68. @property
  69. def param_groups(self):
  70. return self.optimizer.param_groups
  71. @param_groups.setter
  72. def param_groups(self, param_groups):
  73. self.optimizer.param_groups = param_groups
  74. @property
  75. def defaults(self):
  76. return self.optimizer.defaults
  77. @defaults.setter
  78. def defaults(self, defaults):
  79. self.optimizer.defaults = defaults
  80. def add_param_group(self, param_group):
  81. self.optimizer.add_param_group(param_group)
  82. def load_state_dict(self, state_dict):
  83. if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement:
  84. xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device)
  85. self.optimizer.load_state_dict(state_dict)
  86. def state_dict(self):
  87. return self.optimizer.state_dict()
  88. def zero_grad(self, set_to_none=None):
  89. if self.gradient_state.sync_gradients:
  90. accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters
  91. if accept_arg:
  92. if set_to_none is None:
  93. set_to_none = True
  94. self.optimizer.zero_grad(set_to_none=set_to_none)
  95. else:
  96. if set_to_none is not None:
  97. raise ValueError("`set_to_none` for Optimizer.zero_grad` is not supported by this optimizer.")
  98. self.optimizer.zero_grad()
  99. def train(self):
  100. """
  101. Sets the optimizer to "train" mode. Useful for optimizers like `schedule_free`
  102. """
  103. if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
  104. self.optimizer.train()
  105. elif (
  106. hasattr(self.optimizer, "optimizer")
  107. and hasattr(self.optimizer.optimizer, "train")
  108. and callable(self.optimizer.optimizer.train)
  109. ):
  110. # the deepspeed optimizer further wraps the optimizer
  111. self.optimizer.optimizer.train()
  112. def eval(self):
  113. """
  114. Sets the optimizer to "eval" mode. Useful for optimizers like `schedule_free`
  115. """
  116. if hasattr(self.optimizer, "eval") and callable(self.optimizer.eval):
  117. self.optimizer.eval()
  118. def step(self, closure=None):
  119. if is_lomo_available():
  120. from lomo_optim import AdaLomo, Lomo
  121. if (
  122. not self.gradient_state.is_xla_gradients_synced
  123. and self.accelerator_state.distributed_type == DistributedType.XLA
  124. ):
  125. gradients = xm._fetch_gradients(self.optimizer)
  126. xm.all_reduce("sum", gradients, scale=1.0 / xr.world_size())
  127. self.gradient_state.is_xla_gradients_synced = True
  128. if is_lomo_available():
  129. # `step` should be a no-op for LOMO optimizers.
  130. if isinstance(self.optimizer, (Lomo, AdaLomo)):
  131. return
  132. if self.gradient_state.sync_gradients:
  133. if self.scaler is not None:
  134. self.optimizer.step = self._optimizer_patched_step_method
  135. self.scaler.step(self.optimizer, closure)
  136. self.scaler.update()
  137. if not self._accelerate_step_called:
  138. # If the optimizer step was skipped, gradient overflow was detected.
  139. self._is_overflow = True
  140. else:
  141. self._is_overflow = False
  142. # Reset the step method to the original one
  143. self.optimizer.step = self._optimizer_original_step_method
  144. # Reset the indicator
  145. self._accelerate_step_called = False
  146. else:
  147. self.optimizer.step(closure)
  148. if self.accelerator_state.distributed_type == DistributedType.XLA:
  149. self.gradient_state.is_xla_gradients_synced = False
  150. def _switch_parameters(self, parameters_map):
  151. for param_group in self.optimizer.param_groups:
  152. param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]]
  153. @property
  154. def step_was_skipped(self):
  155. """Whether or not the optimizer step was skipped."""
  156. return self._is_overflow
  157. def __getstate__(self):
  158. _ignored_keys = [
  159. "_accelerate_step_called",
  160. "_optimizer_original_step_method",
  161. "_optimizer_patched_step_method",
  162. ]
  163. return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys}
  164. def __setstate__(self, state):
  165. self.__dict__.update(state)
  166. if self.scaler is not None:
  167. self._accelerate_step_called = False
  168. self._optimizer_original_step_method = self.optimizer.step
  169. self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step)
  170. def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method):
  171. def patched_step(*args, **kwargs):
  172. accelerated_optimizer._accelerate_step_called = True
  173. return method(*args, **kwargs)
  174. return patched_step