deepspeed.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. # Copyright 2020 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Integration with Deepspeed
  16. """
  17. import copy
  18. import importlib.metadata as importlib_metadata
  19. import importlib.util
  20. import weakref
  21. from functools import partialmethod
  22. from ..dependency_versions_check import dep_version_check
  23. from ..utils import is_accelerate_available, is_torch_available, logging
  24. if is_torch_available():
  25. import torch
  26. from torch import nn
  27. logger = logging.get_logger(__name__)
  28. def is_deepspeed_available():
  29. package_exists = importlib.util.find_spec("deepspeed") is not None
  30. # Check we're not importing a "deepspeed" directory somewhere but the actual library by trying to grab the version
  31. # AND checking it has an author field in the metadata that is HuggingFace.
  32. if package_exists:
  33. try:
  34. _ = importlib_metadata.metadata("deepspeed")
  35. return True
  36. except importlib_metadata.PackageNotFoundError:
  37. return False
  38. if is_accelerate_available() and is_deepspeed_available():
  39. from accelerate.utils.deepspeed import HfDeepSpeedConfig as DeepSpeedConfig
  40. else:
  41. # Inherits from a dummy `object` if accelerate is not available, so that python succeeds to import this file.
  42. # Deepspeed glue code will never inherit this dummy object as it checks if accelerate is available.
  43. from builtins import object as DeepSpeedConfig
  44. class HfDeepSpeedConfig(DeepSpeedConfig):
  45. """
  46. This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
  47. A `weakref` of this object is stored in the module's globals to be able to access the config from areas where
  48. things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore
  49. it's important that this object remains alive while the program is still running.
  50. [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration
  51. with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic
  52. the DeepSpeed configuration is not modified in any way.
  53. Args:
  54. config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict.
  55. """
  56. def __init__(self, config_file_or_dict):
  57. # set global weakref object
  58. set_hf_deepspeed_config(self)
  59. dep_version_check("accelerate")
  60. dep_version_check("deepspeed")
  61. super().__init__(config_file_or_dict)
  62. class HfTrainerDeepSpeedConfig(HfDeepSpeedConfig):
  63. """
  64. The `HfTrainerDeepSpeedConfig` object is meant to be created during `TrainingArguments` object creation and has the
  65. same lifespan as the latter.
  66. """
  67. def __init__(self, config_file_or_dict):
  68. super().__init__(config_file_or_dict)
  69. self._dtype = None
  70. self.mismatches = []
  71. def dtype(self):
  72. if self._dtype is None:
  73. raise ValueError("trainer_config_process() wasn't called yet to tell dtype")
  74. return self._dtype
  75. def is_auto(self, ds_key_long):
  76. val = self.get_value(ds_key_long)
  77. if val is None:
  78. return False
  79. else:
  80. return val == "auto"
  81. def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
  82. """
  83. A utility method that massages the config file and can optionally verify that the values match.
  84. 1. Replace "auto" values with `TrainingArguments` value.
  85. 2. If it wasn't "auto" and `must_match` is true, then check that DS config matches Trainer
  86. config values and if mismatched add the entry to `self.mismatched` - will assert during
  87. `trainer_config_finalize` for one or more mismatches.
  88. """
  89. config, ds_key = self.find_config_node(ds_key_long)
  90. if config is None:
  91. return
  92. if config.get(ds_key) == "auto":
  93. config[ds_key] = hf_val
  94. return
  95. if not must_match:
  96. return
  97. ds_val = config.get(ds_key)
  98. if ds_val is not None and ds_val != hf_val:
  99. self.mismatches.append(f"- ds {ds_key_long}={ds_val} vs hf {hf_key}={hf_val}")
  100. fill_only = partialmethod(fill_match, must_match=False)
  101. def trainer_config_process(self, args, auto_find_batch_size=False):
  102. """
  103. Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
  104. creation.
  105. """
  106. # DeepSpeed does:
  107. # train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
  108. train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
  109. self.fill_match(
  110. "train_micro_batch_size_per_gpu",
  111. args.per_device_train_batch_size,
  112. "per_device_train_batch_size",
  113. not auto_find_batch_size,
  114. )
  115. self.fill_match(
  116. "gradient_accumulation_steps",
  117. args.gradient_accumulation_steps,
  118. "gradient_accumulation_steps",
  119. )
  120. self.fill_match(
  121. "train_batch_size",
  122. train_batch_size,
  123. "train_batch_size (calculated)",
  124. not auto_find_batch_size,
  125. )
  126. self.fill_match("gradient_clipping", args.max_grad_norm, "max_grad_norm")
  127. self.fill_match("optimizer.params.lr", args.learning_rate, "learning_rate")
  128. self.fill_match(
  129. "optimizer.params.betas",
  130. [args.adam_beta1, args.adam_beta2],
  131. "adam_beta1+adam_beta2",
  132. )
  133. self.fill_match("optimizer.params.eps", args.adam_epsilon, "adam_epsilon")
  134. self.fill_match("optimizer.params.weight_decay", args.weight_decay, "weight_decay")
  135. self.fill_only("scheduler.params.warmup_min_lr", 0) # not a trainer arg
  136. self.fill_match("scheduler.params.warmup_max_lr", args.learning_rate, "learning_rate")
  137. # total_num_steps - will get set in trainer_config_finalize
  138. # fp16
  139. if args.fp16 or args.fp16_full_eval:
  140. fp16_backend = "apex" if args.fp16_backend == "apex" else "amp"
  141. else:
  142. fp16_backend = None
  143. if args.save_on_each_node:
  144. # deepspeed uses shared storage by default. Let's override this setting if save_on_each_node == True
  145. self.config["checkpoint"] = self.config.get("checkpoint", {})
  146. self.config["checkpoint"]["use_node_local_storage"] = args.save_on_each_node
  147. # amp: similar to the pytorch native amp - it has a bunch of optional params but we won't set
  148. # any here unless the user did the work
  149. self.fill_match(
  150. "fp16.enabled",
  151. ((args.fp16 or args.fp16_full_eval) and fp16_backend == "amp"),
  152. "fp16|fp16_full_eval+fp16_backend(amp)",
  153. )
  154. # apex: delegates amp work to apex (which needs to be available), but it cannot be used with any
  155. # ZeRO features
  156. self.fill_match("amp.enabled", fp16_backend == "apex", "fp16+fp16_backend(apex)")
  157. self.fill_match("amp.opt_level", args.fp16_opt_level, "fp16_opt_level")
  158. self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval")
  159. # deepspeed's default mode is fp16 unless there is a config that says differently
  160. if self.is_true("bf16.enabled"):
  161. self._dtype = torch.bfloat16
  162. elif self.is_false("fp16.enabled"):
  163. self._dtype = torch.float32
  164. else:
  165. self._dtype = torch.float16
  166. def trainer_config_finalize(self, args, model, num_training_steps):
  167. """
  168. This stage is run after we have the model and know num_training_steps.
  169. Now we can complete the configuration process.
  170. """
  171. # zero
  172. # deal with config keys that use `auto` value and rely on model's hidden_size
  173. hidden_size_based_keys = [
  174. "zero_optimization.reduce_bucket_size",
  175. "zero_optimization.stage3_prefetch_bucket_size",
  176. "zero_optimization.stage3_param_persistence_threshold",
  177. ]
  178. hidden_size_auto_keys = [x for x in hidden_size_based_keys if self.is_auto(x)]
  179. if len(hidden_size_auto_keys) > 0:
  180. hidden_size = None
  181. if hasattr(model, "config"):
  182. if hasattr(model.config, "hidden_size"):
  183. hidden_size = model.config.hidden_size
  184. elif hasattr(model.config, "hidden_sizes"):
  185. # if there are many hidden sizes pick the largest one
  186. hidden_size = max(model.config.hidden_sizes)
  187. elif hasattr(model.config, "text_config") and hasattr(model.config.text_config, "hidden_size"):
  188. hidden_size = model.config.text_config.hidden_size
  189. elif hasattr(model.config, "text_config") and hasattr(model.config.text_config, "hidden_sizes"):
  190. # if there are many hidden sizes pick the largest one
  191. hidden_size = max(model.config.text_config.hidden_sizes)
  192. if hidden_size is None:
  193. raise ValueError(
  194. "The model's config file has neither `hidden_size` nor `hidden_sizes` entry, "
  195. "therefore it's not possible to automatically fill out the following `auto` entries "
  196. f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing "
  197. "`auto` values for these keys with an integer value of your choice."
  198. )
  199. self.fill_only("zero_optimization.reduce_bucket_size", hidden_size * hidden_size)
  200. if self.is_zero3():
  201. # automatically assign the optimal config values based on model config
  202. self.fill_only(
  203. "zero_optimization.stage3_prefetch_bucket_size",
  204. int(0.9 * hidden_size * hidden_size),
  205. )
  206. self.fill_only(
  207. "zero_optimization.stage3_param_persistence_threshold",
  208. 10 * hidden_size,
  209. )
  210. # scheduler
  211. self.fill_match(
  212. "scheduler.params.total_num_steps",
  213. num_training_steps,
  214. "num_training_steps (calculated)",
  215. )
  216. self.fill_match(
  217. "scheduler.params.warmup_num_steps",
  218. args.get_warmup_steps(num_training_steps),
  219. "warmup_steps",
  220. )
  221. if len(self.mismatches) > 0:
  222. mismatches = "\n".join(self.mismatches)
  223. raise ValueError(
  224. "Please correct the following DeepSpeed config values that mismatch TrainingArguments"
  225. f" values:\n{mismatches}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
  226. )
# keep the config object global to be able to access it anywhere during TrainingArguments life-cycle
# Starts out as `None`; `set_hf_deepspeed_config` stores a `weakref.ref` to the config object here and
# `unset_hf_deepspeed_config` resets it back to `None`.
_hf_deepspeed_config_weak_ref = None
  229. def set_hf_deepspeed_config(hf_deepspeed_config_obj):
  230. # this is a special weakref global object to allow us to get to Deepspeed config from APIs
  231. # that don't have an easy way to get to the Deepspeed config outside of the Trainer domain.
  232. global _hf_deepspeed_config_weak_ref
  233. # will go away automatically when HfDeepSpeedConfig is destroyed (when TrainingArguments is destroyed)
  234. _hf_deepspeed_config_weak_ref = weakref.ref(hf_deepspeed_config_obj)
  235. def unset_hf_deepspeed_config():
  236. # useful for unit tests to ensure the global state doesn't leak - call from `tearDown` method
  237. global _hf_deepspeed_config_weak_ref
  238. _hf_deepspeed_config_weak_ref = None
  239. def is_deepspeed_zero3_enabled():
  240. if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
  241. return _hf_deepspeed_config_weak_ref().is_zero3()
  242. else:
  243. return False
  244. def deepspeed_config():
  245. if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
  246. return _hf_deepspeed_config_weak_ref().config
  247. else:
  248. return None
  249. def _load_state_dict_into_zero3_model(model_to_load, state_dict):
  250. """
  251. Loads state dict into a model specifically for Zero3, since DeepSpeed does not support the `transformers`
  252. tensor parallelism API.
  253. Nearly identical code to PyTorch's `_load_from_state_dict`
  254. """
  255. # copy state_dict so `_load_state_dict_into_zero3_model` can modify it
  256. metadata = getattr(state_dict, "_metadata", None)
  257. state_dict = state_dict.copy()
  258. if metadata is not None:
  259. state_dict._metadata = metadata
  260. error_msgs = []
  261. # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
  262. # so we need to apply the function recursively.
  263. def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
  264. local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
  265. local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
  266. args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
  267. # Parameters of module and children will start with prefix. We can exit early if there are none in this
  268. # state_dict
  269. if is_deepspeed_zero3_enabled() and len([key for key in state_dict if key.startswith(prefix)]) > 0:
  270. import deepspeed
  271. # In sharded models, each shard has only part of the full state_dict, so only gather
  272. # parameters that are in the current state_dict.
  273. named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
  274. params_to_gather = [named_parameters[k] for k in state_dict if k in named_parameters]
  275. if len(params_to_gather) > 0:
  276. # because zero3 puts placeholders in model params, this context
  277. # manager gathers (unpartitions) the params of the current layer, then loads from
  278. # the state dict and then re-partitions them again
  279. with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
  280. if torch.distributed.get_rank() == 0:
  281. module._load_from_state_dict(*args)
  282. for name, child in module._modules.items():
  283. if child is not None:
  284. load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
  285. load(model_to_load, state_dict, assign_to_params_buffers=False)
  286. return error_msgs
  287. def deepspeed_optim_sched(trainer, hf_deepspeed_config, args, num_training_steps, model_parameters):
  288. """
  289. A convenience wrapper that deals with optimizer and lr scheduler configuration.
  290. """
  291. from accelerate.utils import DummyOptim, DummyScheduler
  292. config = hf_deepspeed_config.config
  293. # Mixing and matching DS schedulers and optimizers is supported unless Offload is enabled in which case it's:
  294. # 1. DS scheduler + DS optimizer: Yes
  295. # 2. HF scheduler + HF optimizer: Mostly*
  296. # 3. DS scheduler + HF optimizer: Mostly*
  297. # 4. HF scheduler + DS optimizer: Yes
  298. #
  299. # Mostly*: All non-native DeepSpeed optimizers that have both CPU and GPU implementation should work (except LAMB)
  300. optimizer = None
  301. if "optimizer" in config:
  302. if args.optim == "adafactor":
  303. raise ValueError(
  304. "--adafactor was passed, but also found `optimizer` configured in the DeepSpeed config. "
  305. "Only one optimizer can be configured."
  306. )
  307. optimizer = DummyOptim(params=model_parameters)
  308. else:
  309. if hf_deepspeed_config.is_offload():
  310. logger.info(
  311. "Detected ZeRO Offload and non-DeepSpeed optimizers: This combination should work as long as the"
  312. " custom optimizer has both CPU and GPU implementation (except LAMB)"
  313. )
  314. # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
  315. # But trainer uses AdamW by default.
  316. optimizer = trainer.create_optimizer()
  317. # To use other optimizers requires voiding warranty with: `zero_allow_untested_optimizer`
  318. config["zero_allow_untested_optimizer"] = True
  319. lr_scheduler = None
  320. if "scheduler" in config:
  321. lr_scheduler = DummyScheduler(optimizer)
  322. else:
  323. if isinstance(optimizer, DummyOptim):
  324. def _lr_scheduler_callable(optimizer):
  325. # create a shallow copy first, so later modifications do not affect original trainer
  326. trainer_copy = copy.copy(trainer)
  327. # at the time _lr_scheduler_callable is called, trainer.lr_scheduler has been set
  328. # update it to None so that we can re-create a new scheduler
  329. trainer_copy.lr_scheduler = None
  330. lr_scheduler = trainer_copy.create_scheduler(
  331. num_training_steps=num_training_steps, optimizer=optimizer
  332. )
  333. return lr_scheduler
  334. lr_scheduler = DummyScheduler(optimizer, lr_scheduler_callable=_lr_scheduler_callable)
  335. else:
  336. lr_scheduler = trainer.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
  337. return optimizer, lr_scheduler
  338. def deepspeed_init(trainer, num_training_steps, inference=False):
  339. """
  340. Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.
  341. If `resume_from_checkpoint` was passed then an attempt to resume from a previously saved checkpoint will be made.
  342. Args:
  343. trainer: Trainer object
  344. num_training_steps: per single gpu
  345. resume_from_checkpoint: path to a checkpoint if to resume from after normal DeepSpeedEngine load
  346. inference: launch in inference mode (no optimizer and no lr scheduler)
  347. auto_find_batch_size: whether to ignore the `train_micro_batch_size_per_gpu` argument as it's being
  348. set automatically by the auto batch size finder
  349. Returns: optimizer, lr_scheduler
  350. We may use `deepspeed_init` more than once during the life of Trainer, when we do - it's a temp hack based on:
  351. https://github.com/deepspeedai/DeepSpeed/issues/1394#issuecomment-937405374 until Deepspeed fixes a bug where it
  352. can't resume from a checkpoint after it did some stepping https://github.com/deepspeedai/DeepSpeed/issues/1612
  353. """
  354. from deepspeed.utils import logger as ds_logger
  355. model = trainer.model
  356. args = trainer.args
  357. hf_deepspeed_config = trainer.accelerator.state.deepspeed_plugin.hf_ds_config
  358. # resume config update - some bits like `model` and `num_training_steps` only become available during train
  359. hf_deepspeed_config.trainer_config_finalize(args, model, num_training_steps)
  360. # set the Deepspeed log level consistent with the Trainer
  361. ds_logger.setLevel(args.get_process_log_level())
  362. if inference:
  363. # only Z3 makes sense for the inference
  364. if not hf_deepspeed_config.is_zero3():
  365. raise ValueError("ZeRO inference only makes sense with ZeRO Stage 3 - please adjust your config")
  366. # in case the training config is re-used for inference
  367. hf_deepspeed_config.del_config_sub_tree("optimizer")
  368. hf_deepspeed_config.del_config_sub_tree("lr_scheduler")
  369. optimizer, lr_scheduler = None, None
  370. model_parameters = None
  371. else:
  372. trainer.optimizer = None # important for when deepspeed_init is used as re-init
  373. deepspeed_tp_size = hf_deepspeed_config.config.get("tensor_parallel", {}).get("autotp_size", 1)
  374. if deepspeed_tp_size > 1:
  375. import deepspeed
  376. model = deepspeed.tp_model_init(
  377. model=model,
  378. tp_size=deepspeed_tp_size,
  379. dtype=hf_deepspeed_config.dtype(),
  380. config=hf_deepspeed_config.config,
  381. )
  382. model_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
  383. optimizer, lr_scheduler = deepspeed_optim_sched(
  384. trainer, hf_deepspeed_config, args, num_training_steps, model_parameters
  385. )
  386. # keep for quick debug:
  387. # from pprint import pprint; pprint(config)
  388. return optimizer, lr_scheduler
  389. def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_strict=True):
  390. # it's possible that the user is trying to resume from model_path, which doesn't necessarily
  391. # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
  392. # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
  393. # path contains what looks like a deepspeed checkpoint
  394. import glob
  395. deepspeed_checkpoint_dirs = sorted(glob.glob(f"{checkpoint_path}/global_step*"))
  396. if len(deepspeed_checkpoint_dirs) > 0:
  397. logger.info(f"Attempting to resume from {checkpoint_path}")
  398. # this magically updates self.optimizer and self.lr_scheduler
  399. load_path, _ = deepspeed_engine.load_checkpoint(
  400. checkpoint_path,
  401. load_module_strict=load_module_strict,
  402. load_optimizer_states=True,
  403. load_lr_scheduler_states=True,
  404. )
  405. if load_path is None:
  406. raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
  407. else:
  408. raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")