| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from modelscope.metainfo import Models
- from modelscope.utils.config import ConfigDict
- from modelscope.utils.constant import Tasks
- from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
- from modelscope.utils.logger import get_logger
- from modelscope.utils.registry import Registry, build_from_cfg
- from modelscope.utils.task_utils import get_task_by_subtask_name
# Module-level logger shared by the builder functions in this file.
logger = get_logger()

# Registries used by the builders. BACKBONES is the very same Registry object
# as MODELS, so backbones are looked up in the 'models' registry.
MODELS = Registry('models')
BACKBONES = MODELS
HEADS = Registry('heads')

# Entries of the lazy-import AST index, addressed through INDEX_KEY. The loop
# below reads positions 0, 1 and 2 of each key.
modules = LazyImportModule.get_ast_index()[INDEX_KEY]

# Iterate over a snapshot of the keys because new keys are inserted while
# looping. Every entry indexed under 'BACKBONES' for the backbone task is
# additionally exposed under the MODELS registry name, so it can be resolved
# through MODELS as well.
for module_index in list(modules.keys()):
    if module_index[1] == Tasks.backbone:
        if module_index[0] == 'BACKBONES':
            modules[(MODELS.name.upper(), module_index[1],
                     module_index[2])] = modules[module_index]
def build_model(cfg: ConfigDict,
                task_name: str = None,
                default_args: dict = None):
    """Build a model from a model config dict.

    The model is first built from the ``MODELS`` registry using ``task_name``
    as the group key. If that raises ``KeyError``, ``task_name`` is resolved
    with ``get_task_by_subtask_name``; when a task model type is returned,
    ``cfg['type']`` is overwritten with it (the passed ``cfg`` is modified in
    place) and the build is retried with the parent task as the group key.

    Args:
        cfg (:obj:`ConfigDict`): config dict for model object.
        task_name (str, optional): task name, refer to
            :obj:`Tasks` for more details
        default_args (dict, optional): Default initialization arguments.

    Returns:
        The object produced by ``build_from_cfg``.

    Raises:
        KeyError: The original lookup error, re-raised unchanged, when no
            task model type is available for ``task_name``; or the error
            raised by the retry with the parent task.
    """
    try:
        return build_from_cfg(
            cfg, MODELS, group_key=task_name, default_args=default_args)
    except KeyError:
        # Handle a subtask whose backbone model has not been registered.
        # Every subtask with a parent task should have a task model;
        # otherwise it is not a valid subtask.
        parent_task, task_model_type = get_task_by_subtask_name(task_name)
        if task_model_type is None:
            # Re-raise the original KeyError as-is instead of wrapping it in
            # a second KeyError, which would turn the message into the repr
            # of an exception and bury the original traceback.
            raise
        cfg['type'] = task_model_type
        return build_from_cfg(
            cfg, MODELS, group_key=parent_task, default_args=default_args)
def build_backbone(cfg: ConfigDict, default_args: dict = None):
    """Build a backbone from a backbone config dict.

    ``model_dir`` is read from ``cfg``; it is removed from ``cfg`` unless
    ``cfg['init_backbone']`` is truthy. If building from the ``BACKBONES``
    registry raises ``KeyError``, a warning is logged, ``cfg['type']`` is
    replaced by ``Models.transformers``, ``cfg['model_dir']`` is restored and
    the build is attempted once more. ``cfg`` is modified in place.

    Args:
        cfg (:obj:`ConfigDict`): config dict for backbone object.
        default_args (dict, optional): Default initialization arguments.

    Returns:
        The object produced by ``build_from_cfg``.
    """
    if cfg.get('init_backbone', False):
        # Keep 'model_dir' inside cfg so it is forwarded to the backbone.
        model_dir = cfg.get('model_dir', None)
    else:
        # Strip 'model_dir' from cfg but remember it for the fallback below.
        model_dir = cfg.pop('model_dir', None)

    build_kwargs = dict(group_key=Tasks.backbone, default_args=default_args)
    try:
        return build_from_cfg(cfg, BACKBONES, **build_kwargs)
    except KeyError:
        # The backbone is not in the registry group: fall back to the
        # transformers-backed model type (per the original note, hf
        # transformers AutoModel). That note adds that AutoModel is mostly
        # used in NLP and part of multi-modal, while CV, audio and
        # multi-modal backbones are few enough to be registered in
        # ModelScope directly.
        logger.warning(
            f'The backbone {cfg.type} is not registered in modelscope, try to import the backbone from hf transformers.'
        )
        cfg['type'] = Models.transformers
        cfg['model_dir'] = model_dir
        return build_from_cfg(cfg, BACKBONES, **build_kwargs)
def build_head(cfg: ConfigDict,
               task_name: str = None,
               default_args: dict = None):
    """Build a head from a head config dict.

    Delegates to ``build_from_cfg`` on the ``HEADS`` registry, using
    ``task_name`` as the group key.

    Args:
        cfg (:obj:`ConfigDict`): config dict for head object.
        task_name (str, optional): task name, refer to
            :obj:`Tasks` for more details
        default_args (dict, optional): Default initialization arguments.

    Returns:
        The object produced by ``build_from_cfg``.
    """
    head = build_from_cfg(
        cfg, HEADS, group_key=task_name, default_args=default_args)
    return head
|