# Copyright (c) Alibaba, Inc. and its affiliates. from modelscope.metainfo import Models from modelscope.utils.config import ConfigDict from modelscope.utils.constant import Tasks from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule from modelscope.utils.logger import get_logger from modelscope.utils.registry import Registry, build_from_cfg from modelscope.utils.task_utils import get_task_by_subtask_name logger = get_logger() MODELS = Registry('models') BACKBONES = MODELS HEADS = Registry('heads') modules = LazyImportModule.get_ast_index()[INDEX_KEY] for module_index in list(modules.keys()): if module_index[1] == Tasks.backbone and module_index[0] == 'BACKBONES': modules[(MODELS.name.upper(), module_index[1], module_index[2])] = modules[module_index] def build_model(cfg: ConfigDict, task_name: str = None, default_args: dict = None): """ build model given model config dict Args: cfg (:obj:`ConfigDict`): config dict for model object. task_name (str, optional): task name, refer to :obj:`Tasks` for more details default_args (dict, optional): Default initialization arguments. """ try: model = build_from_cfg( cfg, MODELS, group_key=task_name, default_args=default_args) except KeyError as e: # Handle subtask with a backbone model that hasn't been registered # All the subtask with a parent task should have a task model, otherwise it is not a # valid subtask parent_task, task_model_type = get_task_by_subtask_name(task_name) if task_model_type is None: raise KeyError(e) cfg['type'] = task_model_type model = build_from_cfg( cfg, MODELS, group_key=parent_task, default_args=default_args) return model def build_backbone(cfg: ConfigDict, default_args: dict = None): """ build backbone given backbone config dict Args: cfg (:obj:`ConfigDict`): config dict for backbone object. default_args (dict, optional): Default initialization arguments. """ if not cfg.get('init_backbone', False): model_dir = cfg.pop('model_dir', None) else: model_dir = cfg.get('model_dir', None) try: model = build_from_cfg( cfg, BACKBONES, group_key=Tasks.backbone, default_args=default_args) except KeyError: # Handle backbone that is not in the register group by using transformers AutoModel. # AutoModel are mostly using in NLP and part of Multi-Modal, while the number of backbone in CV、Audio and MM # is limited, thus could be added and registered in Modelscope directly logger.warning( f'The backbone {cfg.type} is not registered in modelscope, try to import the backbone from hf transformers.' ) cfg['type'] = Models.transformers cfg['model_dir'] = model_dir model = build_from_cfg( cfg, BACKBONES, group_key=Tasks.backbone, default_args=default_args) return model def build_head(cfg: ConfigDict, task_name: str = None, default_args: dict = None): """ build head given config dict Args: cfg (:obj:`ConfigDict`): config dict for head object. task_name (str, optional): task name, refer to :obj:`Tasks` for more details default_args (dict, optional): Default initialization arguments. """ return build_from_cfg( cfg, HEADS, group_key=task_name, default_args=default_args)