# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
import inspect
from typing import List, Optional, Tuple, Union

from modelscope.utils.logger import get_logger

TYPE_NAME = 'type'
default_group = 'default'
logger = get_logger()


class Registry(object):
    """Registry that maps names to modules and organizes them into groups.

    Every module is stored under a group key. When no group key is given,
    the module is registered into the ``'default'`` group.
    """

    def __init__(self, name: str):
        self._name = name
        # group key -> {module name -> registered object}
        self._modules = {default_group: {}}

    def __repr__(self):
        format_str = self.__class__.__name__ + f' ({self._name})\n'
        for group_name, group in self._modules.items():
            format_str += (f'group_name={group_name}, '
                           f'modules={list(group.keys())}\n')
        return format_str

    @property
    def name(self):
        return self._name

    @property
    def modules(self):
        return self._modules

    def list(self):
        """Log every group of this registry and the module names it holds."""
        for group_name, group in self._modules.items():
            logger.info(f'group_name={group_name}')
            for m in group.keys():
                logger.info(f'\t{m}')
            logger.info('')

    def get(self, module_key, group_key=default_group):
        """Return the object registered as ``module_key`` in ``group_key``.

        Returns None when either the group or the module is not registered.
        """
        if group_key not in self._modules:
            return None
        return self._modules[group_key].get(module_key, None)

    def _register_module(self,
                         group_key=default_group,
                         module_name=None,
                         module_cls=None,
                         force=False):
        """Store ``module_cls`` under ``module_name`` in ``group_key``.

        Raises:
            KeyError: if the name is already taken in that group and
                ``force`` is False.
        """
        assert isinstance(group_key, str), 'group_key is required and must be str'
        if group_key not in self._modules:
            self._modules[group_key] = dict()

        # Deliberately no `inspect.isclass` check: registered objects may
        # also be plain functions.
        if module_name is None:
            module_name = module_cls.__name__

        if module_name in self._modules[group_key] and not force:
            raise KeyError(f'{module_name} is already registered in '
                           f'{self._name}[{group_key}]')
        self._modules[group_key][module_name] = module_cls
        # Record on the object which group it was registered under.
        module_cls.group_key = group_key

    def register_module(self,
                        group_key: str = default_group,
                        module_name: Optional[str] = None,
                        module_cls: Optional[type] = None,
                        force=False):
        """Register a module, directly or as a decorator.

        Example:
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module('image-classification', 'SwinT')
            >>> class SwinTransformer:
            >>>     pass

            >>> @MODELS.register_module(module_name='SwinDefault')
            >>> class SwinTransformerDefaultGroup:
            >>>     pass

            >>> class SwinTransformer2:
            >>>     pass
            >>> MODELS.register_module('image-classification',
            >>>                        module_name='SwinT2',
            >>>                        module_cls=SwinTransformer2)

        Args:
            group_key: Group the module is registered into; defaults to
                ``'default'``.
            module_name: Name to register under; defaults to the object's
                ``__name__``.
            module_cls: Object to register. When None, a decorator is
                returned instead.
            force (bool, optional): Whether to override an existing entry
                with the same name. Default: False.

        Returns:
            ``module_cls`` when it is given, otherwise a decorator that
            registers and returns its argument.

        Raises:
            TypeError: if ``module_name`` is neither None nor a str.
            KeyError: if the name is already registered and ``force`` is
                False.
        """
        if not (module_name is None or isinstance(module_name, str)):
            raise TypeError(f'module_name must be either of None, str, '
                            f'got {type(module_name)}')

        if module_cls is not None:
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls,
                force=force)
            return module_cls

        # module_cls is None: act as a decorator factory.
        def _register(module_cls):
            self._register_module(
                group_key=group_key,
                module_name=module_name,
                module_cls=module_cls,
                force=force)
            return module_cls

        return _register


def _prefix_exception(exc, obj_cls):
    """Return a new exception of ``exc``'s type, prefixed with the object name.

    Returns None when ``exc``'s type cannot be rebuilt from a single message
    argument, so the caller can re-raise the original instead.
    """
    name = getattr(obj_cls, '__name__', repr(obj_cls))
    try:
        return type(exc)(f'{name}: {exc}')
    except Exception:
        return None


def build_from_cfg(cfg,
                   registry: Registry,
                   group_key: Optional[str] = default_group,
                   default_args: Optional[dict] = None) -> object:
    """Build an object from a config dict.

    The ``'type'`` entry selects what to build. It may be a registered name,
    a class, or a function. All remaining entries are passed as keyword
    arguments. A class is instantiated, and a function is called. If the
    resolved object defines ``_instantiate``, that is called instead.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module('image-classification', 'SwinT')
        >>> class SwinTransformer:
        >>>     pass
        >>> swint = build_from_cfg(dict(type='SwinT'), MODELS,
        >>>                        'image-classification')
        >>> # Returns an instantiated object
        >>>
        >>> @MODELS.register_module()
        >>> def swin_transformer():
        >>>     pass
        >>> result = build_from_cfg(dict(type='swin_transformer'), MODELS)
        >>> # Returns the result of calling the function

    Args:
        cfg (dict): Config dict. The key ``'type'`` must be present here or
            in ``default_args``.
        registry (Registry): The registry to look the type up in.
        group_key (str, optional): Registry group to search. None is treated
            as the default group.
        default_args (dict, optional): Default keyword arguments. Entries
            already present in ``cfg`` take precedence.

    Returns:
        object: The constructed object, or the function's return value.

    Raises:
        TypeError: on invalid ``cfg``/``registry``/``default_args`` or an
            unsupported ``'type'`` value.
        KeyError: if no ``'type'`` is given, or a named type is not
            registered in the group.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an modelscope.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')
    if TYPE_NAME not in cfg:
        if default_args is None or TYPE_NAME not in default_args:
            raise KeyError(
                f'`cfg` or `default_args` must contain the key "{TYPE_NAME}", '
                f'but got {cfg}\n{default_args}')

    # Merge defaults before reading the type, because the type may come
    # from `default_args`. setdefault never overwrites values from `cfg`.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    # Normalize before the group is used in the lazy-import signature.
    if group_key is None:
        group_key = default_group

    # dynamic load installation requirements for this module
    from modelscope.utils.import_utils import LazyImportModule
    sig = (registry.name.upper(), group_key, args[TYPE_NAME])
    LazyImportModule.import_module(sig)

    obj_type = args.pop(TYPE_NAME)
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type, group_key=group_key)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name}'
                f' registry group {group_key}. Please make'
                f' sure the correct version of ModelScope library is used.')
        obj_cls.group_key = group_key
    elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        if hasattr(obj_cls, '_instantiate'):
            return obj_cls._instantiate(**args)
        return obj_cls(**args)
    except Exception as e:
        # A plain TypeError from a bad kwarg does not name the class, so
        # prefix it. If the exception type cannot be rebuilt from one
        # message, re-raise the original rather than masking it.
        wrapped = _prefix_exception(e, obj_cls)
        if wrapped is None:
            raise
        raise wrapped from e