create.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
# Copyright (c) Alibaba, Inc. and its affiliates.
from argparse import ArgumentParser, Namespace, _SubParsersAction

from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
from modelscope.hub.constants import (Licenses, ModelVisibility, Visibility,
                                      VisibilityMap)
from modelscope.hub.utils.aigc import AigcModel
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
from modelscope.utils.logger import get_logger
  10. logger = get_logger()
  11. def subparser_func(args):
  12. """ Function which will be called for a specific sub parser.
  13. """
  14. return CreateCMD(args)
  15. class CreateCMD(CLICommand):
  16. """
  17. Command for creating a new repository, supporting both model and dataset.
  18. """
  19. name = 'create'
  20. def __init__(self, args: _SubParsersAction):
  21. self.args = args
  22. @staticmethod
  23. def define_args(parsers: _SubParsersAction):
  24. parser: ArgumentParser = parsers.add_parser(CreateCMD.name)
  25. parser.add_argument(
  26. 'repo_id',
  27. type=str,
  28. help='The ID of the repo to create (e.g. `username/repo-name`)')
  29. parser.add_argument(
  30. '--token',
  31. type=str,
  32. default=None,
  33. help=
  34. 'A User Access Token generated from https://modelscope.cn/my/myaccesstoken to authenticate the user. '
  35. 'If not provided, the CLI will use the local credentials if available.'
  36. )
  37. parser.add_argument(
  38. '--repo_type',
  39. choices=REPO_TYPE_SUPPORT,
  40. default=REPO_TYPE_MODEL,
  41. help=
  42. 'Type of the repo to create (e.g. `dataset`, `model`). Default to `model`.',
  43. )
  44. parser.add_argument(
  45. '--visibility',
  46. choices=[
  47. Visibility.PUBLIC, Visibility.INTERNAL, Visibility.PRIVATE
  48. ],
  49. default=Visibility.PUBLIC,
  50. help='Visibility of the repo to create. Default to `public`.',
  51. )
  52. parser.add_argument(
  53. '--chinese_name',
  54. type=str,
  55. default=None,
  56. help='Optional, Chinese name of the repo. Default to `None`.',
  57. )
  58. parser.add_argument(
  59. '--license',
  60. type=str,
  61. choices=Licenses.to_list(),
  62. default=Licenses.APACHE_V2,
  63. help=
  64. 'Optional, License of the repo. Default to `Apache License 2.0`.',
  65. )
  66. parser.add_argument(
  67. '--endpoint',
  68. type=str,
  69. default=None,
  70. help='Optional, The modelscope server address. Default to None.',
  71. )
  72. # AIGC specific arguments
  73. aigc_group = parser.add_argument_group(
  74. 'AIGC Model Creation',
  75. 'Arguments for creating an AIGC model. Use --aigc to enable.')
  76. aigc_group.add_argument(
  77. '--aigc', action='store_true', help='Enable AIGC model creation.')
  78. aigc_group.add_argument(
  79. '--from_json',
  80. type=str,
  81. help='Path to a JSON file containing AIGC model configuration. '
  82. 'If used, all other parameters except --repo_id are ignored.')
  83. aigc_group.add_argument(
  84. '--model_path', type=str, help='Path to the model file or folder.')
  85. aigc_group.add_argument(
  86. '--aigc_type',
  87. type=str,
  88. help="AIGC type. Recommended: 'Checkpoint', 'LoRA', 'VAE'.")
  89. aigc_group.add_argument(
  90. '--base_model_type',
  91. type=str,
  92. help='Base model type, e.g., SD_XL.')
  93. aigc_group.add_argument(
  94. '--revision',
  95. type=str,
  96. default='v1.0',
  97. help="Model revision. Defaults to 'v1.0'.")
  98. aigc_group.add_argument(
  99. '--base_model_id',
  100. type=str,
  101. default='',
  102. help='Base model ID from ModelScope.')
  103. aigc_group.add_argument(
  104. '--description',
  105. type=str,
  106. default='This is an AIGC model.',
  107. help='Model description.')
  108. aigc_group.add_argument(
  109. '--path_in_repo',
  110. type=str,
  111. default='',
  112. help='Path in the repository to upload to.')
  113. aigc_group.add_argument(
  114. '--model_source',
  115. type=str,
  116. default='USER_UPLOAD',
  117. help=
  118. 'Source of the AIGC model. `USER_UPLOAD`, `TRAINED_FROM_MODELSCOPE` or `TRAINED_FROM_ALIYUN_FC`.'
  119. )
  120. aigc_group.add_argument(
  121. '--base_model_sub_type',
  122. type=str,
  123. default='',
  124. help='Base model sub type, e.g., Qwen_Edit_2509')
  125. parser.set_defaults(func=subparser_func)
  126. def execute(self):
  127. if self.args.aigc:
  128. if self.args.repo_type != REPO_TYPE_MODEL:
  129. raise ValueError(
  130. 'AIGC models can only be created when repo_type is "model".'
  131. )
  132. self._create_aigc_model()
  133. else:
  134. self._create_regular_repo()
  135. def _create_regular_repo(self):
  136. # Check token and login
  137. # The cookies will be reused if the user has logged in before.
  138. api = HubApi(endpoint=self.args.endpoint)
  139. # Create repo
  140. api.create_repo(
  141. repo_id=self.args.repo_id,
  142. token=self.args.token,
  143. visibility=self.args.visibility,
  144. repo_type=self.args.repo_type,
  145. chinese_name=self.args.chinese_name,
  146. license=self.args.license,
  147. exist_ok=True,
  148. create_default_config=True,
  149. endpoint=self.args.endpoint,
  150. )
  151. def _create_aigc_model(self):
  152. """Execute the command."""
  153. api = HubApi(endpoint=self.args.endpoint)
  154. model_id = self.args.repo_id
  155. if self.args.from_json:
  156. # Create from JSON file
  157. logger.info('Creating AIGC model from JSON file: '
  158. f'{self.args.from_json}')
  159. aigc_model = AigcModel.from_json_file(self.args.from_json)
  160. else:
  161. # Create from command line arguments
  162. logger.info('Creating AIGC model from command line arguments...')
  163. if not all([
  164. self.args.model_path, self.args.aigc_type,
  165. self.args.base_model_type
  166. ]):
  167. raise ValueError(
  168. 'Error: --model_path, --aigc_type, and '
  169. '--base_model_type are required when not using '
  170. '--from_json.')
  171. aigc_model = AigcModel(
  172. model_path=self.args.model_path,
  173. aigc_type=self.args.aigc_type,
  174. base_model_type=self.args.base_model_type,
  175. tag=self.args.revision,
  176. description=self.args.description,
  177. base_model_id=self.args.base_model_id,
  178. path_in_repo=self.args.path_in_repo,
  179. model_source=self.args.model_source,
  180. base_model_sub_type=self.args.base_model_sub_type,
  181. )
  182. # Convert visibility string to int for the API call
  183. reverse_visibility_map = {v: k for k, v in VisibilityMap.items()}
  184. visibility_idx: int = reverse_visibility_map.get(
  185. self.args.visibility, ModelVisibility.PUBLIC)
  186. try:
  187. model_url = api.create_model(
  188. model_id=model_id,
  189. token=self.args.token,
  190. visibility=visibility_idx,
  191. license=self.args.license,
  192. chinese_name=self.args.chinese_name,
  193. aigc_model=aigc_model)
  194. print(f'Successfully created AIGC model: {model_url}')
  195. except Exception as e:
  196. print(f'Error creating AIGC model: {e}')