ddp_hook.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from modelscope.metainfo import Hooks
  3. from modelscope.trainers.hooks.builder import HOOKS
  4. from modelscope.trainers.hooks.hook import Hook
  5. from modelscope.trainers.hooks.priority import Priority
  6. from modelscope.utils.constant import DistributedParallelType
  7. from modelscope.utils.device import create_device
  8. from modelscope.utils.torch_utils import get_local_rank, init_dist
  @HOOKS.register_module(module_name=Hooks.DDPHook)
  class DDPHook(Hook):
      """Trainer hook that prepares a model for distributed data parallel use.

      After the trainer is initialised, the hook initialises the distributed
      environment through ``init_dist``, moves the model onto the CUDA device
      selected by the local rank, and later wraps the model exactly once via
      ``trainer.to_parallel`` before the first training or validation run.
      """

      PRIORITY = Priority.LOW

      def __init__(self, launcher):
          """Create the hook.

          Args:
              launcher(str, required): The launcher info, can be 'pytorch'
                  or 'mpi' or 'slurm'.

          Raises:
              AssertionError: If ``launcher`` is ``None``.
          """
          assert launcher is not None
          # Tracks whether ``wrap_module`` has already wrapped the model.
          self.wrapped = False
          self.launcher = launcher

      # TODO support single GPU evaluate & multi GPU train

      def after_init(self, trainer):
          """Initialise the distributed setup and place the model on its GPU.

          Args:
              trainer: The trainer whose ``device``, ``model`` and
                  ``parallel_groups`` attributes are updated in place.
          """
          # The distributed environment is initialised first, before the
          # local rank is queried.
          init_dist(self.launcher)

          local_rank = get_local_rank()
          trainer.device = create_device(f'cuda:{local_rank}')
          trainer.model.to(trainer.device)

          # NOTE(review): presumably ``None`` selects the default process
          # group for data parallelism -- confirm against the trainer.
          trainer.parallel_groups[DistributedParallelType.DP] = None

      def before_run(self, trainer):
          """Make sure the model is wrapped before training starts."""
          self.wrap_module(trainer)

      def before_val(self, trainer):
          """Make sure the model is wrapped before validation starts."""
          self.wrap_module(trainer)

      def wrap_module(self, trainer):
          """Wrap ``trainer.model`` with ``trainer.to_parallel`` at most once.

          Later calls are no-ops, so it is safe to call this from both
          ``before_run`` and ``before_val``.
          """
          if self.wrapped:
              return

          trainer.model = trainer.to_parallel(trainer.model)
          self.wrapped = True