distributed.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import math
  16. import torch
  17. import torch.distributed as dist
  18. from megatron_util import mpu
  19. from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
  20. from torch.autograd import Variable
  21. from torch.nn.modules import Module
  22. def normal_init_method(mean, std):
  23. def init_(tensor):
  24. return torch.nn.init.normal_(tensor, mean=mean, std=std)
  25. return init_
  26. def scaled_init_method(mean, std, num_layers):
  27. """Init method based on N(0, sigma/sqrt(2*num_layers)."""
  28. std = std / math.sqrt(2.0 * num_layers)
  29. def init_(tensor):
  30. return torch.nn.init.normal_(tensor, mean=mean, std=std)
  31. return init_
  32. class DistributedDataParallel(Module):
  33. def __init__(self, module):
  34. super(DistributedDataParallel, self).__init__()
  35. self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False
  36. self.module = module
  37. self.data_parallel_group = mpu.get_data_parallel_group()
  38. src_rank = mpu.get_tensor_model_parallel_rank()
  39. for p in self.module.parameters():
  40. if torch.is_tensor(p):
  41. dist.broadcast(p, src_rank, group=self.data_parallel_group)
  42. def allreduce_params(reduce_after=True,
  43. no_scale=False,
  44. fp32_allreduce=False):
  45. if (self.needs_reduction):
  46. self.needs_reduction = False
  47. buckets = {}
  48. for name, param in self.module.named_parameters():
  49. if param.requires_grad and param.grad is not None:
  50. tp = (param.data.type())
  51. if tp not in buckets:
  52. buckets[tp] = []
  53. buckets[tp].append(param)
  54. if self.warn_on_half:
  55. if torch.cuda.HalfTensor in buckets:
  56. print(
  57. 'WARNING: gloo dist backend for half parameters may be extremely slow.',
  58. 'It is recommended to use the NCCL backend in this case.'
  59. )
  60. self.warn_on_half = False
  61. for tp in buckets:
  62. bucket = buckets[tp]
  63. grads = [param.grad.data for param in bucket]
  64. coalesced = _flatten_dense_tensors(grads)
  65. if fp32_allreduce:
  66. coalesced = coalesced.float()
  67. if not no_scale and not reduce_after:
  68. coalesced /= dist.get_world_size(
  69. group=self.data_parallel_group)
  70. dist.all_reduce(coalesced, group=self.data_parallel_group)
  71. torch.cuda.synchronize()
  72. if not no_scale and reduce_after:
  73. coalesced /= dist.get_world_size(
  74. group=self.data_parallel_group)
  75. for buf, synced in zip(
  76. grads, _unflatten_dense_tensors(coalesced, grads)):
  77. buf.copy_(synced)
  78. self.hook_handles = []
  79. self.hooks = []
  80. for param in list(self.module.parameters()):
  81. def allreduce_hook(*unused):
  82. Variable._execution_engine.queue_callback(allreduce_params)
  83. self.allreduce_params = allreduce_params
  84. def forward(self, *inputs, **kwargs):
  85. self.needs_reduction = True
  86. return self.module(*inputs, **kwargs)
  87. def state_dict(self, destination=None, prefix='', keep_vars=False):
  88. sd = self.module.state_dict(destination, prefix, keep_vars)
  89. return sd
  90. def load_state_dict(self, state_dict, strict=True):
  91. self.module.load_state_dict(state_dict, strict=strict)