deepspeed.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. # Copyright 2021 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import base64
  15. import json
  16. import os
  17. from copy import deepcopy
  18. from torch import optim
  19. from ..optimizer import AcceleratedOptimizer
  20. from ..scheduler import AcceleratedScheduler
  21. from .dataclasses import DistributedType
  22. from .imports import is_bnb_available
  23. from .versions import compare_versions
  24. def map_pytorch_optim_to_deepspeed(optimizer):
  25. """
  26. Args:
  27. optimizer: torch.optim.Optimizer
  28. Returns the DeepSeedCPUOptimizer (deepspeed.ops) version of the optimizer.
  29. """
  30. defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]}
  31. # Select the DeepSpeedCPUOptimizer based on the original optimizer class.
  32. # DeepSpeedCPUAdam is the default
  33. from deepspeed.ops.adam import DeepSpeedCPUAdam
  34. optimizer_class = DeepSpeedCPUAdam
  35. # For DeepSpeedCPUAdam (adamw_mode)
  36. if compare_versions("deepspeed", ">=", "0.3.1"):
  37. defaults["adamw_mode"] = False
  38. is_adaw = isinstance(optimizer, optim.AdamW)
  39. if is_bnb_available() and not is_adaw:
  40. import bitsandbytes.optim as bnb_opt
  41. if isinstance(optimizer, (bnb_opt.AdamW, bnb_opt.AdamW32bit)):
  42. try:
  43. is_adaw = optimizer.optim_bits == 32
  44. except AttributeError:
  45. is_adaw = optimizer.args.optim_bits == 32
  46. else:
  47. is_adaw = False
  48. if is_adaw:
  49. defaults["adamw_mode"] = True
  50. # For DeepSpeedCPUAdagrad
  51. if compare_versions("deepspeed", ">=", "0.5.5"):
  52. # Check if the optimizer is PyTorch's Adagrad.
  53. is_ada = isinstance(optimizer, optim.Adagrad)
  54. # If not, and bitsandbytes is available,
  55. # # check if the optimizer is the 32-bit bitsandbytes Adagrad.
  56. if is_bnb_available() and not is_ada:
  57. import bitsandbytes.optim as bnb_opt
  58. if isinstance(optimizer, (bnb_opt.Adagrad, bnb_opt.Adagrad32bit)):
  59. try:
  60. is_ada = optimizer.optim_bits == 32
  61. except AttributeError:
  62. is_ada = optimizer.args.optim_bits == 32
  63. if is_ada:
  64. from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad
  65. optimizer_class = DeepSpeedCPUAdagrad
  66. # For DeepSpeedCPULion
  67. if is_bnb_available(min_version="0.38.0") and compare_versions("deepspeed", ">=", "0.11.0"):
  68. from bitsandbytes.optim import Lion, Lion32bit
  69. if isinstance(optimizer, (Lion, Lion32bit)):
  70. try:
  71. is_bnb_32bits = optimizer.optim_bits == 32
  72. except AttributeError:
  73. is_bnb_32bits = optimizer.args.optim_bits == 32
  74. if is_bnb_32bits:
  75. from deepspeed.ops.lion import DeepSpeedCPULion
  76. optimizer_class = DeepSpeedCPULion
  77. return optimizer_class(optimizer.param_groups, **defaults)
  78. def get_active_deepspeed_plugin(state):
  79. """
  80. Returns the currently active DeepSpeedPlugin.
  81. Raises:
  82. ValueError: If DeepSpeed was not enabled and this function is called.
  83. """
  84. if state.distributed_type != DistributedType.DEEPSPEED:
  85. raise ValueError(
  86. "Couldn't retrieve the active `DeepSpeedPlugin` as none were enabled. "
  87. "Please make sure that either `Accelerator` is configured for `deepspeed` "
  88. "or make sure that the desired `DeepSpeedPlugin` has been enabled (`AcceleratorState().select_deepspeed_plugin(name)`) "
  89. "before calling this function."
  90. )
  91. if not isinstance(state.deepspeed_plugins, dict):
  92. return state.deepspeed_plugins
  93. return next(plugin for plugin in state.deepspeed_plugins.values() if plugin.selected)
  94. class HfDeepSpeedConfig:
  95. """
  96. This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
  97. A `weakref` of this object is stored in the module's globals to be able to access the config from areas where
  98. things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore
  99. it's important that this object remains alive while the program is still running.
  100. [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
  101. with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
  102. the DeepSpeed configuration is not modified in any way.
  103. Args:
  104. config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.
  105. """
  106. def __init__(self, config_file_or_dict):
  107. if isinstance(config_file_or_dict, dict):
  108. # Don't modify user's data should they want to reuse it (e.g. in tests), because once we
  109. # modified it, it will not be accepted here again, since `auto` values would have been overridden
  110. config = deepcopy(config_file_or_dict)
  111. elif os.path.exists(config_file_or_dict):
  112. with open(config_file_or_dict, encoding="utf-8") as f:
  113. config = json.load(f)
  114. else:
  115. try:
  116. try:
  117. # First try parsing as JSON directly
  118. config = json.loads(config_file_or_dict)
  119. except json.JSONDecodeError:
  120. # If that fails, try base64 decoding
  121. config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8")
  122. config = json.loads(config_decoded)
  123. except (UnicodeDecodeError, AttributeError, ValueError):
  124. raise ValueError(
  125. f"Expected a string path to an existing deepspeed config, or a dictionary, or a base64 encoded string. Received: {config_file_or_dict}"
  126. )
  127. self.config = config
  128. self.set_stage_and_offload()
  129. def set_stage_and_offload(self):
  130. # zero stage - this is done as early as possible, before model is created, to allow
  131. # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object
  132. # during ``zero.Init()`` which needs to know the dtype, and some other hparams.
  133. self._stage = self.get_value("zero_optimization.stage", -1)
  134. # offload
  135. self._offload = False
  136. if self.is_zero2() or self.is_zero3():
  137. offload_devices_valid = set(["cpu", "nvme"])
  138. offload_devices = set(
  139. [
  140. self.get_value("zero_optimization.offload_optimizer.device"),
  141. self.get_value("zero_optimization.offload_param.device"),
  142. ]
  143. )
  144. if len(offload_devices & offload_devices_valid) > 0:
  145. self._offload = True
  146. def find_config_node(self, ds_key_long):
  147. config = self.config
  148. # find the config node of interest if it exists
  149. nodes = ds_key_long.split(".")
  150. ds_key = nodes.pop()
  151. for node in nodes:
  152. config = config.get(node)
  153. if config is None:
  154. return None, ds_key
  155. return config, ds_key
  156. def get_value(self, ds_key_long, default=None):
  157. """
  158. Returns the set value or `default` if no value is set
  159. """
  160. config, ds_key = self.find_config_node(ds_key_long)
  161. if config is None:
  162. return default
  163. return config.get(ds_key, default)
  164. def del_config_sub_tree(self, ds_key_long, must_exist=False):
  165. """
  166. Deletes a sub-section of the config file if it's found.
  167. Unless `must_exist` is `True` the section doesn't have to exist.
  168. """
  169. config = self.config
  170. # find the config node of interest if it exists
  171. nodes = ds_key_long.split(".")
  172. for node in nodes:
  173. parent_config = config
  174. config = config.get(node)
  175. if config is None:
  176. if must_exist:
  177. raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}")
  178. else:
  179. return
  180. # if found remove it
  181. if parent_config is not None:
  182. parent_config.pop(node)
  183. def is_true(self, ds_key_long):
  184. """
  185. Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very
  186. specific question of whether the value is set to `True` (and it's not set to `False`` or isn't set).
  187. """
  188. value = self.get_value(ds_key_long)
  189. return False if value is None else bool(value)
  190. def is_false(self, ds_key_long):
  191. """
  192. Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very
  193. specific question of whether the value is set to `False` (and it's not set to `True`` or isn't set).
  194. """
  195. value = self.get_value(ds_key_long)
  196. return False if value is None else not bool(value)
  197. def is_zero2(self):
  198. return self._stage == 2
  199. def is_zero3(self):
  200. return self._stage == 3
  201. def is_offload(self):
  202. return self._offload
  203. class DeepSpeedEngineWrapper:
  204. """
  205. Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. This is used to follow conventional training loop.
  206. Args:
  207. engine (deepspeed.runtime.engine.DeepSpeedEngine): deepspeed engine to wrap
  208. """
  209. def __init__(self, engine):
  210. self.engine = engine
  211. def backward(self, loss, sync_gradients=True, **kwargs):
  212. # Set gradient accumulation boundary based on Accelerate's sync_gradients state
  213. # This tells DeepSpeed whether this is the final micro-batch before gradient sync
  214. self.engine.set_gradient_accumulation_boundary(is_boundary=sync_gradients)
  215. # runs backpropagation and handles mixed precision
  216. self.engine.backward(loss, **kwargs)
  217. # Only perform step and related operations at gradient accumulation boundaries
  218. if sync_gradients:
  219. # Deepspeed's `engine.step` performs the following operations:
  220. # - gradient accumulation check
  221. # - gradient clipping
  222. # - optimizer step
  223. # - zero grad
  224. # - checking overflow
  225. # - lr_scheduler step (only if engine.lr_scheduler is not None)
  226. self.engine.step()
  227. # and this plugin overrides the above calls with no-ops when Accelerate runs under
  228. # Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabling a simple
  229. # training loop that works transparently under many training regimes.
  230. def get_global_grad_norm(self):
  231. """Get the global gradient norm from DeepSpeed engine."""
  232. grad_norm = self.engine.get_global_grad_norm()
  233. # Convert to scalar if it's a tensor
  234. if hasattr(grad_norm, "item"):
  235. return grad_norm.item()
  236. return grad_norm
  237. class DeepSpeedOptimizerWrapper(AcceleratedOptimizer):
  238. """
  239. Internal wrapper around a deepspeed optimizer.
  240. Args:
  241. optimizer (`torch.optim.optimizer.Optimizer`):
  242. The optimizer to wrap.
  243. """
  244. def __init__(self, optimizer):
  245. super().__init__(optimizer, device_placement=False, scaler=None)
  246. self.__has_overflow__ = hasattr(self.optimizer, "overflow")
  247. def zero_grad(self, set_to_none=None):
  248. pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
  249. def step(self):
  250. pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
  251. @property
  252. def step_was_skipped(self):
  253. """Whether or not the optimizer step was done, or skipped because of gradient overflow."""
  254. if self.__has_overflow__:
  255. return self.optimizer.overflow
  256. return False
  257. class DeepSpeedSchedulerWrapper(AcceleratedScheduler):
  258. """
  259. Internal wrapper around a deepspeed scheduler.
  260. Args:
  261. scheduler (`torch.optim.lr_scheduler.LambdaLR`):
  262. The scheduler to wrap.
  263. optimizers (one or a list of `torch.optim.Optimizer`):
  264. """
  265. def __init__(self, scheduler, optimizers):
  266. super().__init__(scheduler, optimizers)
  267. def step(self):
  268. pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
  269. class DummyOptim:
  270. """
  271. Dummy optimizer presents model parameters or param groups, this is primarily used to follow conventional training
  272. loop when optimizer config is specified in the deepspeed config file.
  273. Args:
  274. lr (float):
  275. Learning rate.
  276. params (iterable): iterable of parameters to optimize or dicts defining
  277. parameter groups
  278. weight_decay (float):
  279. Weight decay.
  280. **kwargs (additional keyword arguments, *optional*):
  281. Other arguments.
  282. """
  283. def __init__(self, params, lr=0.001, weight_decay=0, **kwargs):
  284. self.params = params
  285. self.lr = lr
  286. self.weight_decay = weight_decay
  287. self.kwargs = kwargs
  288. class DummyScheduler:
  289. """
  290. Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training
  291. loop when scheduler config is specified in the deepspeed config file.
  292. Args:
  293. optimizer (`torch.optim.optimizer.Optimizer`):
  294. The optimizer to wrap.
  295. total_num_steps (int, *optional*):
  296. Total number of steps.
  297. warmup_num_steps (int, *optional*):
  298. Number of steps for warmup.
  299. lr_scheduler_callable (callable, *optional*):
  300. A callable function that creates an LR Scheduler. It accepts only one argument `optimizer`.
  301. **kwargs (additional keyword arguments, *optional*):
  302. Other arguments.
  303. """
  304. def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, lr_scheduler_callable=None, **kwargs):
  305. self.optimizer = optimizer
  306. self.total_num_steps = total_num_steps
  307. self.warmup_num_steps = warmup_num_steps
  308. self.lr_scheduler_callable = lr_scheduler_callable
  309. self.kwargs = kwargs