builder.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from modelscope.metainfo import Models
  3. from modelscope.utils.config import ConfigDict
  4. from modelscope.utils.constant import Tasks
  5. from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
  6. from modelscope.utils.logger import get_logger
  7. from modelscope.utils.registry import Registry, build_from_cfg
  8. from modelscope.utils.task_utils import get_task_by_subtask_name
  logger = get_logger()

  # Model registry. Backbones do not get a registry of their own:
  # `BACKBONES` is just another name bound to the very same object as `MODELS`.
  MODELS = Registry('models')
  BACKBONES = MODELS
  HEADS = Registry('heads')

  # Lazy-import AST index. Its keys are indexed below as tuples with at least
  # three items; presumably (registry variable name, group key, module name)
  # -- TODO confirm against LazyImportModule.
  modules = LazyImportModule.get_ast_index()[INDEX_KEY]

  # Entries indexed under the name 'BACKBONES' in the `Tasks.backbone` group
  # are duplicated under the MODELS registry's (upper-cased) name, since both
  # names refer to one registry. The keys are snapshotted with list() because
  # new keys are inserted into `modules` while walking it.
  for module_index in list(modules):
      if not (module_index[1] == Tasks.backbone
              and module_index[0] == 'BACKBONES'):
          continue
      modules[(MODELS.name.upper(), module_index[1],
               module_index[2])] = modules[module_index]
  def build_model(cfg: ConfigDict,
                  task_name: str = None,
                  default_args: dict = None):
      """Build a model from a model config dict.

      The model is first built from ``MODELS`` using ``task_name`` as the
      group key. If that raises ``KeyError``, ``task_name`` is treated as a
      subtask: when ``get_task_by_subtask_name`` reports a task model type for
      it, ``cfg['type']`` is replaced by that type and the model is built from
      the parent task's group instead.

      Args:
          cfg (:obj:`ConfigDict`): config dict for the model object. On the
              subtask fallback path ``cfg['type']`` is overwritten in place.
          task_name (str, optional): task name, refer to
              :obj:`Tasks` for more details.
          default_args (dict, optional): default initialization arguments.

      Returns:
          The object returned by ``build_from_cfg``.

      Raises:
          KeyError: if building for ``task_name`` fails with ``KeyError`` and
              no task model type exists for it as a subtask. The original
              ``KeyError`` is re-raised unchanged (message and traceback).
      """
      try:
          return build_from_cfg(
              cfg, MODELS, group_key=task_name, default_args=default_args)
      except KeyError:
          # Handle a subtask whose backbone model has not been registered.
          # Every subtask with a parent task must have a task model; otherwise
          # it is not a valid subtask and the original lookup error stands.
          parent_task, task_model_type = get_task_by_subtask_name(task_name)
          if task_model_type is None:
              # Bare `raise` propagates the original KeyError as-is, instead
              # of wrapping it as KeyError(KeyError(...)) which garbles the
              # message and hides the original raise site.
              raise
          cfg['type'] = task_model_type
          return build_from_cfg(
              cfg, MODELS, group_key=parent_task, default_args=default_args)
  def build_backbone(cfg: ConfigDict, default_args: dict = None):
      """Build a backbone from a backbone config dict.

      The backbone is built from ``BACKBONES`` in the ``Tasks.backbone``
      group. If that raises ``KeyError``, the config is switched to the
      ``Models.transformers`` type (with ``model_dir`` restored) and built
      again. This is meant to load backbones that are not registered in
      modelscope through hf transformers.

      Args:
          cfg (:obj:`ConfigDict`): config dict for the backbone object. It is
              modified in place: ``model_dir`` is popped from it unless
              ``cfg.init_backbone`` is truthy, and on the fallback path
              ``type`` and ``model_dir`` are overwritten.
          default_args (dict, optional): default initialization arguments.

      Returns:
          The object returned by ``build_from_cfg``.
      """
      # `model_dir` only reaches the first build attempt when `init_backbone`
      # is set; either way its value is kept for the transformers fallback.
      keep_model_dir = cfg.get('init_backbone', False)
      model_dir = (cfg.get('model_dir', None)
                   if keep_model_dir else cfg.pop('model_dir', None))

      def _build_from_registry():
          # Build from BACKBONES with whatever `cfg` currently holds.
          return build_from_cfg(
              cfg,
              BACKBONES,
              group_key=Tasks.backbone,
              default_args=default_args)

      try:
          return _build_from_registry()
      except KeyError:
          # The backbone is not in the registry group: fall back to the
          # transformers AutoModel wrapper. AutoModel mostly covers NLP and
          # part of multi-modal; CV, audio and MM backbones are few enough to
          # be registered in modelscope directly.
          logger.warning(
              f'The backbone {cfg.type} is not registered in modelscope, try to import the backbone from hf transformers.'
          )
          cfg['type'] = Models.transformers
          cfg['model_dir'] = model_dir
          return _build_from_registry()
  def build_head(cfg: ConfigDict,
                 task_name: str = None,
                 default_args: dict = None):
      """Build a head from a head config dict.

      Args:
          cfg (:obj:`ConfigDict`): config dict for the head object.
          task_name (str, optional): task name used as the registry group
              key; refer to :obj:`Tasks` for more details.
          default_args (dict, optional): default initialization arguments.

      Returns:
          The object returned by ``build_from_cfg`` for the ``HEADS`` registry.
      """
      registry_kwargs = {'group_key': task_name, 'default_args': default_args}
      head = build_from_cfg(cfg, HEADS, **registry_kwargs)
      return head