utils.py 754 B

1234567891011121314151617181920212223
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from .builder import PARALLEL
  3. def is_parallel(module):
  4. """Check if a module is wrapped by parallel object.
  5. The following modules are regarded as parallel object:
  6. - torch.nn.parallel.DataParallel
  7. - torch.nn.parallel.distributed.DistributedDataParallel
  8. You may add you own parallel object by registering it to `modelscope.parallel.PARALLEL`.
  9. Args:
  10. module (nn.Module): The module to be checked.
  11. Returns:
  12. bool: True if the is wrapped by parallel object.
  13. """
  14. module_wrappers = []
  15. for group, module_dict in PARALLEL.modules.items():
  16. module_wrappers.extend(list(module_dict.values()))
  17. return isinstance(module, tuple(module_wrappers))