builder.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from modelscope.metainfo import Trainers
  3. from modelscope.pipelines.builder import normalize_model_input
  4. from modelscope.pipelines.util import is_official_hub_path
  5. from modelscope.utils.config import check_config
  6. from modelscope.utils.constant import DEFAULT_MODEL_REVISION
  7. from modelscope.utils.hub import read_config
  8. from modelscope.utils.plugins import (register_modelhub_repo,
  9. register_plugins_repo)
  10. from modelscope.utils.registry import Registry, build_from_cfg
  # Registry that build_trainer resolves trainer names against.
  TRAINERS = Registry('trainers')


  def build_trainer(name: str = Trainers.default, default_args: dict = None):
      """Build a trainer given a trainer name.

      When ``default_args['model']`` is a string, or a non-empty list whose
      first element is a string, and ``is_official_hub_path`` accepts it, the
      model configuration is read first so that the plugins it declares and
      the model repo are registered before the trainer is built.

      Args:
          name (str, optional): Trainer name. If None, ``Trainers.default``
              is used.
          default_args (dict, optional): Default initialization arguments.
              The ``model`` and ``model_revision`` keys are read here.
              None is treated as an empty dict.

      Returns:
          The result of ``build_from_cfg`` for ``dict(type=name)`` looked up
          in ``TRAINERS``.
      """
      # Honour the documented contract: a None name means the default trainer.
      if name is None:
          name = Trainers.default
      cfg = dict(type=name)

      # The signature defaults to None, so normalise before calling .get().
      args = {} if default_args is None else default_args
      model = args.get('model', None)
      model_revision = args.get('model_revision', DEFAULT_MODEL_REVISION)

      # Pick the single identifier the configuration is read from. An empty
      # list is skipped instead of raising IndexError on model[0].
      if isinstance(model, str):
          config_source = model
      elif isinstance(model, list) and model and isinstance(model[0], str):
          config_source = model[0]
      else:
          config_source = None

      if config_source is not None and is_official_hub_path(
              model, revision=model_revision):
          # Read the config file from the hub and parse it.
          configuration = read_config(config_source, revision=model_revision)
          model_dir = normalize_model_input(model, model_revision)
          register_plugins_repo(configuration.safe_get('plugins'))
          register_modelhub_repo(model_dir,
                                 configuration.get('allow_remote', False))

      return build_from_cfg(cfg, TRAINERS, default_args=args)