local_sgd.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import torch
  15. from accelerate import Accelerator, DistributedType
  16. class LocalSGD:
  17. """
  18. A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently
  19. on each device, and averages model weights every K synchronization step.
  20. It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular,
  21. this is a simple implementation that cannot support scenarios such as model parallelism.
  22. Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes
  23. back to at least:
  24. Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint
  25. arXiv:1606.07365.](https://huggingface.co/papers/1606.07365)
  26. We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of).
  27. Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on
  28. Learning Representations. No. CONF. 2019.](https://huggingface.co/papers/1805.09767)
  29. """
  30. def __enter__(self):
  31. if self.enabled:
  32. self.model_sync_obj = self.model.no_sync()
  33. self.model_sync_obj.__enter__()
  34. return self
  35. def __exit__(self, type, value, tb):
  36. if self.enabled:
  37. # Average all models on exit
  38. self._sync_and_avg_model_params()
  39. self.model_sync_obj.__exit__(type, value, tb)
  40. def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True):
  41. """
  42. Constructor.
  43. Args:
  44. model (`torch.nn.Module):
  45. The model whose parameters we need to average.
  46. accelerator (`Accelerator`):
  47. Accelerator object.
  48. local_sgd_steps (`int`):
  49. A number of local SGD steps (before model parameters are synchronized).
  50. enabled (`bool):
  51. Local SGD is disabled if this parameter set to `False`.
  52. """
  53. if accelerator.distributed_type not in [
  54. DistributedType.NO,
  55. DistributedType.MULTI_CPU,
  56. DistributedType.MULTI_GPU,
  57. DistributedType.MULTI_XPU,
  58. DistributedType.MULTI_MLU,
  59. DistributedType.MULTI_HPU,
  60. DistributedType.MULTI_SDAA,
  61. DistributedType.MULTI_MUSA,
  62. DistributedType.MULTI_NPU,
  63. ]:
  64. raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)")
  65. self.enabled = enabled and accelerator.distributed_type != DistributedType.NO
  66. self.num_steps = 0
  67. if self.enabled:
  68. self.accelerator = accelerator
  69. self.model = model
  70. self.local_sgd_steps = local_sgd_steps
  71. def step(self):
  72. """
  73. This function makes a "step" and synchronizes model parameters if necessary.
  74. """
  75. self.num_steps += 1
  76. if not self.enabled:
  77. return
  78. if self.num_steps % self.local_sgd_steps == 0:
  79. self._sync_and_avg_model_params()
  80. def _sync_and_avg_model_params(self):
  81. """
  82. Synchronize + Average model parameters across all GPUs
  83. """
  84. self.accelerator.wait_for_everyone()
  85. with self.accelerator.autocast():
  86. for param in self.model.parameters():
  87. param.data = self.accelerator.reduce(param.data, reduction="mean")