builder.py 681 B

1234567891011121314151617181920
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Optional

from torch.nn.parallel.distributed import DistributedDataParallel

from modelscope.utils.config import ConfigDict
from modelscope.utils.registry import Registry, build_from_cfg
  5. PARALLEL = Registry('parallel')
  6. PARALLEL.register_module(
  7. module_name='DistributedDataParallel', module_cls=DistributedDataParallel)
  8. def build_parallel(cfg: ConfigDict, default_args: dict = None):
  9. """ build parallel
  10. Args:
  11. cfg (:obj:`ConfigDict`): config dict for parallel object.
  12. default_args (dict, optional): Default initialization arguments.
  13. """
  14. return build_from_cfg(cfg, PARALLEL, default_args=default_args)