cmdline.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652
  1. #!/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. 命令行
  10. Authors: xiangyiqing(xiangyiqing@baidu.com),suoyi@baidu.com
  11. Date: 2024/03/05
  12. """
  13. import sys
  14. import argparse
  15. import click
  16. import os
  17. from aistudio_sdk import log
  18. from aistudio_sdk.sdk import pipeline
  19. from aistudio_sdk.file_download import model_file_download, file_download
  20. from aistudio_sdk.snapshot_download import snapshot_download
  21. from aistudio_sdk.utils.util import convert_patterns
  22. from aistudio_sdk.config import (DEFAULT_MAX_WORKERS, REPO_TYPE_SUPPORT, REPO_TYPE_MODEL,
  23. DEFAULT_DATASET_REVISION, REPO_TYPE_DATASET)
  24. from aistudio_sdk.hub import upload_file, upload_folder
# Public API of this module: only the CLI entry point is exported.
__all__ = [
    'main',
]
  28. class CustomHelpFormatter(argparse.RawTextHelpFormatter):
  29. """
  30. 自定义帮助信息格式
  31. """
  32. pass
  33. def init():
  34. """
  35. 构建CLI Parser
  36. """
  37. log.cli_log()
  38. parser = argparse.ArgumentParser(prog='PROG', formatter_class=CustomHelpFormatter)
  39. subparser_aistudio = parser.add_subparsers(
  40. help='AI Studio CLI SDK',
  41. dest='command'
  42. )
  43. # config 子命令,用于身份认证和日志级别设置
  44. # 用法示例:
  45. # aistudio config -t <token> -l info
  46. config = subparser_aistudio.add_parser(
  47. 'config',
  48. help='首次使用AI Studio CLI管理任务时, 需要先使用AI Studio账号的访问令牌进行身份认证。\
  49. 一次认证后,再次使用时无需认证。'
  50. )
  51. config.add_argument(
  52. '-t', '--token',
  53. type=str,
  54. required=False,
  55. default='',
  56. help='AI Studio账号的访问令牌'
  57. )
  58. config.add_argument(
  59. '-l', '--log',
  60. type=str,
  61. required=False,
  62. default='',
  63. choices=['info', 'debug', ''],
  64. help='日志级别'
  65. )
  66. # submit 子命令,用于提交SDK产线任务
  67. # 用法示例:
  68. # aistudio submit job -n <name> -p <path> -c <cmd> -e <env> -d <device> -g <gpus> -pay <payment> -m <mount_dataset>
  69. submit = subparser_aistudio.add_parser(
  70. 'submit',
  71. help='提交SDK产线任务'
  72. )
  73. subparser_submit = submit.add_subparsers()
  74. # submit job 子命令及其参数
  75. submit_job = subparser_submit.add_parser(
  76. 'job',
  77. help='提交SDK产线任务'
  78. )
  79. submit_job.add_argument(
  80. '-n', '--name',
  81. type=str,
  82. required=True,
  83. dest='summit_name',
  84. help='产线任务名称'
  85. )
  86. submit_job.add_argument(
  87. '-p', '--path',
  88. type=str,
  89. required=True,
  90. help='代码包本地路径(文件夹),要求文件总体积不超过50MB'
  91. )
  92. submit_job.add_argument(
  93. '-c', '--cmd',
  94. type=str,
  95. required=True,
  96. help='任务启动命令'
  97. )
  98. submit_job.add_argument(
  99. '-e', '--env',
  100. type=str,
  101. required=False,
  102. default='paddle2.6_py3.10',
  103. choices=['paddle2.4_py3.7', 'paddle2.5_py3.10', 'paddle2.6_py3.10', 'paddle3.0_py3.10'],
  104. help='飞桨框架版本, 默认paddle2.6_py3.10'
  105. )
  106. submit_job.add_argument(
  107. '-d', '--device',
  108. type=str,
  109. required=False,
  110. default='v100',
  111. choices=['v100'],
  112. help='硬件资源, 默认v100'
  113. )
  114. submit_job.add_argument(
  115. '-g', '--gpus',
  116. type=int,
  117. required=False,
  118. default='1',
  119. choices=[1, 4, 8],
  120. help='gpu数量, 默认单卡'
  121. )
  122. submit_job.add_argument(
  123. '-pay', '--payment',
  124. type=str,
  125. required=False,
  126. default='acoin',
  127. choices=['acoin', 'coupon'],
  128. help='计费方式: * acoin-A币 * coupon-算力点. 默认使用A币'
  129. )
  130. submit_job.add_argument(
  131. '-m', '--mount_dataset',
  132. action='append',
  133. type=int,
  134. required=False,
  135. default=[],
  136. help='数据集挂载, 单个任务最多挂载3个'
  137. )
  138. # jobs 子命令,用于查询SDK产线任务
  139. # 用法示例:
  140. # aistudio jobs <query_pipeline_id> -n <name> -s <status>
  141. jobs = subparser_aistudio.add_parser(
  142. 'jobs',
  143. help='查询SDK产线任务'
  144. )
  145. jobs.add_argument(
  146. 'query_pipeline_id',
  147. type=str,
  148. nargs='?',
  149. default='',
  150. help='产线id'
  151. )
  152. jobs.add_argument(
  153. '-n', '--name',
  154. type=str,
  155. required=False,
  156. default='',
  157. help='产线名称'
  158. )
  159. jobs.add_argument(
  160. '-s', '--status',
  161. type=str,
  162. required=False,
  163. default='',
  164. help='状态'
  165. )
  166. # stop 子命令,用于停止SDK产线任务
  167. # 用法示例:
  168. # aistudio stop job <stop_pipeline_id> -f
  169. stop = subparser_aistudio.add_parser(
  170. 'stop',
  171. help='停止SDK产线任务'
  172. )
  173. subparser_stop = stop.add_subparsers()
  174. # stop job 子命令及其参数
  175. stop_job = subparser_stop.add_parser(
  176. 'job',
  177. help='停止SDK产线任务'
  178. )
  179. stop_job.add_argument(
  180. 'stop_pipeline_id',
  181. type=str,
  182. help='产线id'
  183. )
  184. stop_job.add_argument(
  185. '-f', '--force',
  186. action='store_true',
  187. help='强制停止,无需二次确认'
  188. )
  189. # 创建主命令解析器
  190. job = subparser_aistudio.add_parser(
  191. 'job',
  192. help='管理SDK产线任务'
  193. )
  194. # 添加 'job_id' 参数
  195. job.add_argument(
  196. 'job_id',
  197. type=str,
  198. help='任务ID'
  199. )
  200. # 创建job子命令的解析器
  201. subparser_job = job.add_subparsers(dest='command', required=True, help='job子命令')
  202. # 'ls' 子命令,用于查询 output 目录下的文件
  203. job_ls = subparser_job.add_parser(
  204. 'ls',
  205. help='查询某个 job 的 output 目录下文件夹内容'
  206. )
  207. job_ls.add_argument(
  208. 'directory',
  209. type=str,
  210. nargs='?',
  211. default='',
  212. help='输出目录路径'
  213. )
  214. # 'cp' 子命令,用于下载 output 目录下的文件到本地
  215. job_cp = subparser_job.add_parser(
  216. 'cp',
  217. help='下载某个 job 的 output 目录下的文件到本地'
  218. )
  219. job_cp.add_argument(
  220. 'result_file',
  221. type=str,
  222. help='结果文件路径'
  223. )
  224. job_cp.add_argument(
  225. 'local_path',
  226. type=str,
  227. help='本地保存路径'
  228. )
  229. # 许可证ID到许可证名称的映射
  230. license_mapping = {
  231. 1: '公共领域 (CC0)',
  232. 2: '署名 (CC BY 4.0)',
  233. 3: '署名-非商业性使用-相同方式共享 (CC BY-NC-SA 4.0)',
  234. 4: '署名-相同方式共享 (CC BY-SA 4.0)',
  235. 5: '署名-禁止演绎 (CC-BY-ND)',
  236. 6: '自由软件基金会 (GPL 2)',
  237. 7: '署名-允许演绎 (ODC-BY)',
  238. 8: '其他'
  239. }
  240. # 创建主命令解析器
  241. dataset = subparser_aistudio.add_parser(
  242. 'dataset',
  243. help='管理数据集,此命令不在支持,请使用新的命令',
  244. formatter_class=CustomHelpFormatter
  245. )
  246. # 构建许可证的帮助信息,每个选项单独一行
  247. license_help = (
  248. "数据集许可协议的ID,仅在设置public后生效。默认为1 (公共领域 CC0)。\n"
  249. "可选项包括:\n" + '\n'.join(f" {k}: {v}" for k, v in license_mapping.items())
  250. )
  251. # 添加 dataset 子命令
  252. datasets_create = dataset.add_subparsers(help='数据集操作')
  253. # 创建数据集的子命令(create)
  254. # aistudio datasets create [flags]
  255. #
  256. # flags:
  257. # --name ppocr_v1 (required) (-n)
  258. # --files ./file.zip (required) (文件路径,支持多文件上传)(-f)
  259. # --tags 大模型 (optional) (-t)
  260. # --public (optional, 默认不公开)(-p)
  261. # --license CC0 (optional,默认CC0,只在设置public后生效 )(-l)
  262. # --description testdata (optional) (-d)
  263. create = datasets_create.add_parser(
  264. 'create',
  265. help='创建数据集',
  266. formatter_class=CustomHelpFormatter
  267. )
  268. create.add_argument(
  269. '-n', '--name',
  270. type=str,
  271. required=True,
  272. help='数据集名称'
  273. )
  274. create.add_argument(
  275. '-f', '--files',
  276. type=str,
  277. required=True,
  278. nargs='+',
  279. help='本地文件路径,支持多个文件'
  280. )
  281. create.add_argument(
  282. '-p', '--public',
  283. action='store_true',
  284. help='是否公开数据集'
  285. )
  286. create.add_argument(
  287. '-l', '--license',
  288. type=int,
  289. required=False,
  290. choices=list(license_mapping.keys()),
  291. default=1,
  292. help=license_help
  293. )
  294. create.add_argument(
  295. '-d', '--description',
  296. type=str,
  297. required=False,
  298. help='数据集描述'
  299. )
  300. # # ** 上传数据集文件 ******************
  301. # aistudio datasets add [flags]
  302. #
  303. # flags:
  304. # --id 123645 (required) (数据集id) (-i)
  305. # --files ./file.zip (required) (文件路径)(-f)
  306. add = datasets_create.add_parser(
  307. 'add',
  308. help='上传数据集文件',
  309. formatter_class=CustomHelpFormatter
  310. )
  311. add.add_argument(
  312. '-id', '--id',
  313. type=int,
  314. required=True,
  315. help='数据集id'
  316. )
  317. add.add_argument(
  318. '-f', '--files',
  319. type=str,
  320. required=True,
  321. nargs='+',
  322. help='本地文件路径,支持多个文件'
  323. )
  324. # 新增model模块
  325. download = subparser_aistudio.add_parser(
  326. 'download',
  327. help='下载文件',
  328. formatter_class=CustomHelpFormatter
  329. )
  330. download.add_argument(
  331. '--model',
  332. type=str,
  333. help='模型ID,例如 myname/myrepoid'
  334. )
  335. download.add_argument(
  336. '--dataset',
  337. type=str,
  338. help='The id of the dataset to be downloaded. For download, '
  339. 'the id of either a model or dataset must be provided.')
  340. download.add_argument(
  341. '--revision',
  342. type=str,
  343. default=None,
  344. help='Revision of the entity.')
  345. download.add_argument(
  346. '--local_dir',
  347. type=str,
  348. default=None,
  349. help='File will be downloaded to local location specified by'
  350. 'local_dir, in this case.')
  351. download.add_argument(
  352. 'files',
  353. type=str,
  354. default=None,
  355. nargs='*',
  356. help='Specify relative path to the repository file(s) to download.'
  357. "(e.g 'tokenizer.json', 'dir/decoder_model.onnx').")
  358. download.add_argument(
  359. '--include',
  360. nargs='*',
  361. default=None,
  362. type=str,
  363. help='Glob patterns to match files to download.'
  364. 'Ignored if file is specified')
  365. download.add_argument(
  366. '--exclude',
  367. nargs='*',
  368. type=str,
  369. default=None,
  370. help='Glob patterns to exclude from files to download.'
  371. 'Ignored if file is specified')
  372. download.add_argument(
  373. '--token',
  374. type=str,
  375. default=None,
  376. help='A User Access Token'
  377. )
  378. download.add_argument(
  379. '--max-workers',
  380. type=int,
  381. default=DEFAULT_MAX_WORKERS,
  382. help='The maximum number of workers to download files.')
  383. upload = subparser_aistudio.add_parser(
  384. 'upload',
  385. help='上传文件',
  386. formatter_class=CustomHelpFormatter)
  387. upload.add_argument(
  388. 'repo_id',
  389. type=str,
  390. help='The ID of the repo to upload to (e.g. `username/repo-name`)')
  391. upload.add_argument(
  392. 'local_path',
  393. type=str,
  394. nargs='?',
  395. default=None,
  396. help='Optional, '
  397. 'Local path to the file or folder to upload. Defaults to current directory.'
  398. )
  399. upload.add_argument(
  400. 'path_in_repo',
  401. type=str,
  402. nargs='?',
  403. default=None,
  404. help='Optional, '
  405. 'Path of the file or folder in the repo. Defaults to the relative path of the file or folder.'
  406. )
  407. upload.add_argument(
  408. '--repo-type',
  409. choices=REPO_TYPE_SUPPORT,
  410. default=REPO_TYPE_MODEL,
  411. help='Type of the repo to upload to (e.g. `dataset`, `model`). Defaults to be `model`.',
  412. )
  413. upload.add_argument(
  414. '--include',
  415. nargs='*',
  416. type=str,
  417. help='Glob patterns to match files to upload.')
  418. upload.add_argument(
  419. '--exclude',
  420. nargs='*',
  421. type=str,
  422. help='Glob patterns to exclude from files to upload.')
  423. upload.add_argument(
  424. '--commit-message',
  425. type=str,
  426. default=None,
  427. help='The message of commit. Default to be `None`.')
  428. upload.add_argument(
  429. '--token',
  430. type=str,
  431. default=None,
  432. help='A User Access Token'
  433. )
  434. upload.add_argument(
  435. '--max-workers',
  436. type=int,
  437. default=min(8,
  438. os.cpu_count() + 4),
  439. help='The number of workers to use for uploading files.')
  440. return parser
  441. cache_home = os.getenv("AISTUDIO_CACHE_HOME", default=os.getenv("HOME"))
  442. TOKEN_FILE = os.path.expanduser(f'{cache_home}/.cache/aistudio/.auth/token')
  443. def save_token(token):
  444. """
  445. save to separate location
  446. """
  447. print(token)
  448. with open(TOKEN_FILE, 'w') as f:
  449. f.write(str(token))
  450. os.chmod(TOKEN_FILE, 0o600)
  451. def main():
  452. """CLI入口"""
  453. parser = init()
  454. args = sys.argv[1:]
  455. print(f"{args}")
  456. try:
  457. args = parser.parse_args(args)
  458. except:
  459. return
  460. if getattr(args, 'command', None) == 'upload':
  461. assert args.repo_id, '`repo_id` is required'
  462. assert args.repo_id.count(
  463. '/') == 1, 'repo_id should be in format of username/repo-name'
  464. repo_name: str = args.repo_id.split('/')[-1]
  465. parser.repo_id = args.repo_id
  466. # Check path_in_repo
  467. if args.local_path is None and os.path.isfile(repo_name):
  468. # Case 1: modelscope upload owner_name/test_repo
  469. parser.local_path = repo_name
  470. parser.path_in_repo = repo_name
  471. elif args.local_path is None and os.path.isdir(repo_name):
  472. # Case 2: modelscope upload owner_name/test_repo (run command line in the `repo_name` dir)
  473. # => upload all files in current directory to remote root path
  474. parser.local_path = repo_name
  475. parser.path_in_repo = '.'
  476. elif args.local_path is None:
  477. # Case 3: user provided only a repo_id that does not match a local file or folder
  478. # => the user must explicitly provide a local_path => raise exception
  479. raise ValueError(
  480. f"'{repo_name}' is not a local file or folder. Please set `local_path` explicitly."
  481. )
  482. elif args.path_in_repo is None and os.path.isfile(
  483. args.local_path):
  484. # Case 4: modelscope upload owner_name/test_repo /path/to/your_file.csv
  485. # => upload it to remote root path with same name
  486. parser.local_path = args.local_path
  487. parser.path_in_repo = os.path.basename(args.local_path)
  488. elif args.path_in_repo is None:
  489. # Case 5: modelscope upload owner_name/test_repo /path/to/your_folder
  490. # => upload all files in current directory to remote root path
  491. parser.local_path = args.local_path
  492. parser.path_in_repo = ''
  493. else:
  494. # Finally, if both paths are explicit
  495. parser.local_path = args.local_path
  496. parser.path_in_repo = args.path_in_repo
  497. if os.path.isfile(parser.local_path):
  498. upload_file(
  499. path_or_fileobj=parser.local_path,
  500. path_in_repo=parser.path_in_repo,
  501. repo_id=parser.repo_id,
  502. repo_type=args.repo_type,
  503. commit_message=args.commit_message,
  504. token=args.token,
  505. )
  506. elif os.path.isdir(parser.local_path):
  507. upload_folder(
  508. repo_id=parser.repo_id,
  509. folder_path=parser.local_path,
  510. path_in_repo=parser.path_in_repo,
  511. commit_message=args.commit_message,
  512. repo_type=args.repo_type,
  513. allow_patterns=convert_patterns(args.include),
  514. ignore_patterns=convert_patterns(args.exclude),
  515. max_workers=args.max_workers,
  516. token=args.token,
  517. )
  518. else:
  519. raise ValueError(f'{parser.local_path} is not a valid local path')
  520. print(f'Finished uploading to {parser.repo_id}')
  521. elif hasattr(args, 'model') and args.model:
  522. if len(args.files) == 1: # download single file
  523. model_file_download(
  524. args.model,
  525. args.files[0],
  526. local_dir=args.local_dir,
  527. revision=args.revision,
  528. token=args.token
  529. )
  530. elif len(
  531. args.files) > 1: # download specified multiple files.
  532. snapshot_download(
  533. repo_id=args.model,
  534. revision=args.revision,
  535. local_dir=args.local_dir,
  536. allow_patterns=args.files,
  537. max_workers=args.max_workers,
  538. token=args.token
  539. )
  540. else: # download repo
  541. snapshot_download(
  542. repo_id=args.model,
  543. revision=args.revision,
  544. local_dir=args.local_dir,
  545. allow_patterns=convert_patterns(args.include),
  546. ignore_patterns=convert_patterns(args.exclude),
  547. max_workers=args.max_workers,
  548. token=args.token
  549. )
  550. elif hasattr(args, 'dataset') and args.dataset:
  551. dataset_revision: str = args.revision if args.revision else DEFAULT_DATASET_REVISION
  552. if len(args.files) == 1: # download single file
  553. file_download(
  554. args.dataset,
  555. args.files[0],
  556. local_dir=args.local_dir,
  557. revision=dataset_revision,
  558. repo_type=REPO_TYPE_DATASET,
  559. token=args.token
  560. )
  561. elif len(
  562. args.files) > 1: # download specified multiple files.
  563. snapshot_download(
  564. repo_id=args.dataset,
  565. revision=dataset_revision,
  566. local_dir=args.local_dir,
  567. allow_patterns=args.files,
  568. max_workers=args.max_workers,
  569. token=args.token
  570. )
  571. else: # download repo
  572. snapshot_download(
  573. repo_id=args.dataset,
  574. revision=dataset_revision,
  575. local_dir=args.local_dir,
  576. allow_patterns=convert_patterns(args.include),
  577. ignore_patterns=convert_patterns(args.exclude),
  578. max_workers=args.max_workers,
  579. token=args.token
  580. )
  581. print(
  582. f'\nSuccessfully Downloaded from dataset {args.dataset}.\n'
  583. )
  584. elif "token" in args:
  585. pipeline.set_config(args)
  586. elif "summit_name" in args:
  587. pipeline.create(args)
  588. elif "query_pipeline_id" in args:
  589. pipeline.query(args)
  590. elif "stop_pipeline_id" in args:
  591. if not args.force:
  592. # 二次确认
  593. if not click.confirm('Do you want to continue?', default=False):
  594. log.info('Aborted.')
  595. return
  596. log.info('Confirmed.')
  597. pipeline.stop(args)
  598. elif "directory" in args:
  599. # 查询某个 job 的 output 目录下文件夹内容
  600. pipeline.list_output_files(args)
  601. elif "result_file" in args and "local_path" in args:
  602. # 下载某个 job 的 output 目录下的文件到本地
  603. pipeline.download_output_file(args)
  604. elif "name" in args and "files" in args:
  605. # 创建数据集
  606. log.error("This command is not supported any more")
  607. pipeline.create_dataset(args)
  608. elif "id" in args and "files" in args:
  609. # 上传数据集文件
  610. log.error("This command is not supported any more")
  611. pipeline.add_file(args)
  612. else:
  613. log.info("无效的命令")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()