# Copyright (c) Alibaba, Inc. and its affiliates.
from .builder import PARALLEL


def is_parallel(module):
    """Check if a module is wrapped by a parallel object.

    The following modules are regarded as parallel objects:
     - torch.nn.parallel.DataParallel
     - torch.nn.parallel.distributed.DistributedDataParallel

    You may add your own parallel object by registering it to
    `modelscope.parallel.PARALLEL`.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: True if the module is wrapped by a parallel object.
    """
    wrapper_types = []
    # Only the registered classes are needed; the group keys are irrelevant.
    for module_dict in PARALLEL.modules.values():
        wrapper_types.extend(module_dict.values())
    # isinstance() requires a tuple of types; an empty tuple yields False.
    return isinstance(module, tuple(wrapper_types))