| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from argparse import ArgumentParser, _SubParsersAction
- from modelscope.cli.base import CLICommand
- from modelscope.hub.api import HubApi
- from modelscope.hub.constants import (Licenses, ModelVisibility, Visibility,
- VisibilityMap)
- from modelscope.hub.utils.aigc import AigcModel
- from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
- from modelscope.utils.logger import get_logger
- logger = get_logger()
- def subparser_func(args):
- """ Function which will be called for a specific sub parser.
- """
- return CreateCMD(args)
- class CreateCMD(CLICommand):
- """
- Command for creating a new repository, supporting both model and dataset.
- """
- name = 'create'
- def __init__(self, args: _SubParsersAction):
- self.args = args
- @staticmethod
- def define_args(parsers: _SubParsersAction):
- parser: ArgumentParser = parsers.add_parser(CreateCMD.name)
- parser.add_argument(
- 'repo_id',
- type=str,
- help='The ID of the repo to create (e.g. `username/repo-name`)')
- parser.add_argument(
- '--token',
- type=str,
- default=None,
- help=
- 'A User Access Token generated from https://modelscope.cn/my/myaccesstoken to authenticate the user. '
- 'If not provided, the CLI will use the local credentials if available.'
- )
- parser.add_argument(
- '--repo_type',
- choices=REPO_TYPE_SUPPORT,
- default=REPO_TYPE_MODEL,
- help=
- 'Type of the repo to create (e.g. `dataset`, `model`). Default to `model`.',
- )
- parser.add_argument(
- '--visibility',
- choices=[
- Visibility.PUBLIC, Visibility.INTERNAL, Visibility.PRIVATE
- ],
- default=Visibility.PUBLIC,
- help='Visibility of the repo to create. Default to `public`.',
- )
- parser.add_argument(
- '--chinese_name',
- type=str,
- default=None,
- help='Optional, Chinese name of the repo. Default to `None`.',
- )
- parser.add_argument(
- '--license',
- type=str,
- choices=Licenses.to_list(),
- default=Licenses.APACHE_V2,
- help=
- 'Optional, License of the repo. Default to `Apache License 2.0`.',
- )
- parser.add_argument(
- '--endpoint',
- type=str,
- default=None,
- help='Optional, The modelscope server address. Default to None.',
- )
- # AIGC specific arguments
- aigc_group = parser.add_argument_group(
- 'AIGC Model Creation',
- 'Arguments for creating an AIGC model. Use --aigc to enable.')
- aigc_group.add_argument(
- '--aigc', action='store_true', help='Enable AIGC model creation.')
- aigc_group.add_argument(
- '--from_json',
- type=str,
- help='Path to a JSON file containing AIGC model configuration. '
- 'If used, all other parameters except --repo_id are ignored.')
- aigc_group.add_argument(
- '--model_path', type=str, help='Path to the model file or folder.')
- aigc_group.add_argument(
- '--aigc_type',
- type=str,
- help="AIGC type. Recommended: 'Checkpoint', 'LoRA', 'VAE'.")
- aigc_group.add_argument(
- '--base_model_type',
- type=str,
- help='Base model type, e.g., SD_XL.')
- aigc_group.add_argument(
- '--revision',
- type=str,
- default='v1.0',
- help="Model revision. Defaults to 'v1.0'.")
- aigc_group.add_argument(
- '--base_model_id',
- type=str,
- default='',
- help='Base model ID from ModelScope.')
- aigc_group.add_argument(
- '--description',
- type=str,
- default='This is an AIGC model.',
- help='Model description.')
- aigc_group.add_argument(
- '--path_in_repo',
- type=str,
- default='',
- help='Path in the repository to upload to.')
- aigc_group.add_argument(
- '--model_source',
- type=str,
- default='USER_UPLOAD',
- help=
- 'Source of the AIGC model. `USER_UPLOAD`, `TRAINED_FROM_MODELSCOPE` or `TRAINED_FROM_ALIYUN_FC`.'
- )
- aigc_group.add_argument(
- '--base_model_sub_type',
- type=str,
- default='',
- help='Base model sub type, e.g., Qwen_Edit_2509')
- parser.set_defaults(func=subparser_func)
- def execute(self):
- if self.args.aigc:
- if self.args.repo_type != REPO_TYPE_MODEL:
- raise ValueError(
- 'AIGC models can only be created when repo_type is "model".'
- )
- self._create_aigc_model()
- else:
- self._create_regular_repo()
- def _create_regular_repo(self):
- # Check token and login
- # The cookies will be reused if the user has logged in before.
- api = HubApi(endpoint=self.args.endpoint)
- # Create repo
- api.create_repo(
- repo_id=self.args.repo_id,
- token=self.args.token,
- visibility=self.args.visibility,
- repo_type=self.args.repo_type,
- chinese_name=self.args.chinese_name,
- license=self.args.license,
- exist_ok=True,
- create_default_config=True,
- endpoint=self.args.endpoint,
- )
- def _create_aigc_model(self):
- """Execute the command."""
- api = HubApi(endpoint=self.args.endpoint)
- model_id = self.args.repo_id
- if self.args.from_json:
- # Create from JSON file
- logger.info('Creating AIGC model from JSON file: '
- f'{self.args.from_json}')
- aigc_model = AigcModel.from_json_file(self.args.from_json)
- else:
- # Create from command line arguments
- logger.info('Creating AIGC model from command line arguments...')
- if not all([
- self.args.model_path, self.args.aigc_type,
- self.args.base_model_type
- ]):
- raise ValueError(
- 'Error: --model_path, --aigc_type, and '
- '--base_model_type are required when not using '
- '--from_json.')
- aigc_model = AigcModel(
- model_path=self.args.model_path,
- aigc_type=self.args.aigc_type,
- base_model_type=self.args.base_model_type,
- tag=self.args.revision,
- description=self.args.description,
- base_model_id=self.args.base_model_id,
- path_in_repo=self.args.path_in_repo,
- model_source=self.args.model_source,
- base_model_sub_type=self.args.base_model_sub_type,
- )
- # Convert visibility string to int for the API call
- reverse_visibility_map = {v: k for k, v in VisibilityMap.items()}
- visibility_idx: int = reverse_visibility_map.get(
- self.args.visibility, ModelVisibility.PUBLIC)
- try:
- model_url = api.create_model(
- model_id=model_id,
- token=self.args.token,
- visibility=visibility_idx,
- license=self.args.license,
- chinese_name=self.args.chinese_name,
- aigc_model=aigc_model)
- print(f'Successfully created AIGC model: {model_url}')
- except Exception as e:
- print(f'Error creating AIGC model: {e}')
|