hub.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2023 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. 本文件实现了模型库hub接口封装
  10. TODO:
  11. 当前脚本后续将移动至sdk目录下, 但用法将发生变化, 需和pm确认
  12. 旧:
  13. from aistudio_sdk.hub import create_repo
  14. create_repo()
  15. 新:
  16. from aistudio_sdk import hub
  17. hub.create_repo()
  18. Authors: linyichong(linyichong@baidu.com)
  19. Date: 2023/08/21
  20. """
  21. from typing import Optional
  22. import requests
  23. import os
  24. import io
  25. import logging
  26. import traceback
  27. from pathlib import Path
  28. from aistudio_sdk.constant.err_code import ErrorEnum
  29. from aistudio_sdk.requests.hub import request_aistudio_hub, request_aistudio_app_service
  30. from aistudio_sdk.requests.hub import request_aistudio_git_file_info, commit_files
  31. from aistudio_sdk.requests.hub import request_aistudio_git_file_type, request_aistudio_git_files_type
  32. from aistudio_sdk.requests.hub import request_aistudio_git_upload_access
  33. from aistudio_sdk.requests.hub import request_bos_upload
  34. from aistudio_sdk.requests.hub import request_aistudio_git_upload_pointer
  35. from aistudio_sdk.requests.hub import request_aistudio_git_upload_common, request_single_git_upload_common
  36. from aistudio_sdk.requests.hub import get_exist_file_old_sha
  37. from aistudio_sdk.requests.hub import request_aistudio_repo_visible
  38. from aistudio_sdk.requests.hub import request_aistudio_verify_lfs_file, request_single_git_upload_pointer
  39. from aistudio_sdk.utils.util import convert_to_dict_object, is_valid_host, calculate_sha256
  40. from aistudio_sdk.utils.util import err_resp
  41. from aistudio_sdk.utils.util import (extract_yaml_block, is_readme_md, get_file_size,
  42. get_file_hash, thread_executor)
  43. from aistudio_sdk import log
  44. from aistudio_sdk import config
  45. from aistudio_sdk.dot import post_upload_statistic_async
  46. from typing import (List, Union, BinaryIO, Iterable, Callable, Generator, TypeVar,
  47. Dict, Any, Literal, Iterator)
  48. from dataclasses import dataclass
  49. from fnmatch import fnmatch
  50. from contextlib import contextmanager
  51. T = TypeVar('T')
  52. __all__ = [
  53. "create_repo",
  54. "upload",
  55. "file_exists",
  56. "upload_folder",
  57. "upload_file"
  58. ]
  59. UploadMode = Literal['lfs', 'normal']
  60. FORBIDDEN_FOLDERS = ['.git', '.cache']
  61. class UploadFileException(Exception):
  62. """
  63. 上传文件异常
  64. """
  65. pass
  66. class Hub():
  67. """Hub类"""
  68. OBJECT_NAME = "hub"
  69. def __init__(self):
  70. """初始化函数,从本地磁盘加载AI Studio认证令牌。
  71. Args:
  72. 无参数。
  73. Returns:
  74. 无返回值。
  75. """
  76. # 当用户已经设置了AISTUDIO_ACCESS_TOKEN环境变量,那么优先读取环境变量,忽略本地磁盘存的token
  77. # 未设置时才读存本地token
  78. if not os.getenv("AISTUDIO_ACCESS_TOKEN", default=""):
  79. cache_home = os.getenv("AISTUDIO_CACHE_HOME", default=os.getenv("HOME"))
  80. token_file_path = f'{cache_home}/.cache/aistudio/.auth/token'
  81. if os.path.exists(token_file_path):
  82. with open(token_file_path, 'r') as file:
  83. os.environ["AISTUDIO_ACCESS_TOKEN"] = file.read().strip()
  84. self.upload_checker = UploadingCheck()
  85. def create_repo(self, **kwargs):
  86. """
  87. 创建一个repo仓库并返回创建成功后的信息。
  88. Params:
  89. repo_id (str): 仓库名称,格式为user_name/repo_name 或者 repo_name,必填。
  90. repo_type (str): 仓库类型,取值为app/model,分别为应用仓库和模型仓库。如果未指定,默认为model。
  91. app_name (str): 应用名称,如果repo_type为app,则必填。默认值为repo_id (如果不填,后端自动生成)。
  92. app_sdk (str): 应用SDK, 如果repo_type为app,则必填,可以填写 streamlit, gradio, static 三种
  93. version (str): streamlit 或 gradio 版本,必填
  94. * gradio版本支持"4.26.0", "4.0.0"
  95. * streamlit版本支持"1.33.0", "1.30.0"
  96. model_name (str): 模型名称,如果repo_type为model,则必填。默认值为repo_id。
  97. desc (str): 仓库描述,可选,默认为空。
  98. license (str): 仓库许可证,可选,默认为"Apache License 2.0"。
  99. private (bool): 是否私有仓库,可选,默认为False。
  100. token (str): 认证令牌,可选,默认为环境变量的值。
  101. Demo:
  102. 创建应用仓库:
  103. create_repo(repo_id='app_repo_0425',
  104. app_sdk='streamlit',
  105. version="1.33.0"
  106. desc='my app demo')
  107. Returns:
  108. dict: 仓库创建结果。
  109. """
  110. params = {}
  111. if "repo_id" not in kwargs:
  112. return err_resp(ErrorEnum.PARAMS_INVALID.code, ErrorEnum.PARAMS_INVALID.message)
  113. # 设置默认repo_type为'model'
  114. repo_type = kwargs.get('repo_type', 'model')
  115. if repo_type == 'app':
  116. if 'app_name' not in kwargs:
  117. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  118. ErrorEnum.PARAMS_INVALID.message + "should provide param app_name")
  119. app_sdk = kwargs.get('app_sdk')
  120. if not app_sdk or app_sdk not in ['streamlit', 'gradio', 'static']:
  121. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  122. ErrorEnum.PARAMS_INVALID.message + "app_sdk should be streamlit, gradio or static.")
  123. if app_sdk == "streamlit":
  124. if 'version' not in kwargs:
  125. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  126. "streamlit version needed.")
  127. params["streamlitVersion"] = kwargs['version']
  128. if app_sdk == "gradio":
  129. if 'version' not in kwargs:
  130. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  131. "gradio version needed.")
  132. params["gradioVersion"] = kwargs['version']
  133. elif repo_type == 'model' and 'model_name' not in kwargs:
  134. kwargs['model_name'] = kwargs.get('repo_id')
  135. if 'private' in kwargs and not isinstance(kwargs['private'], bool):
  136. return err_resp(ErrorEnum.PARAMS_INVALID.code, "private should be bool type.")
  137. for key in ['repo_id', 'model_name', 'license', 'token']:
  138. if key in kwargs:
  139. if not isinstance(kwargs[key], str):
  140. return err_resp(ErrorEnum.PARAMS_INVALID.code, "should be str type: " + key)
  141. kwargs[key] = kwargs[key].strip()
  142. if not kwargs[key]:
  143. return err_resp(ErrorEnum.PARAMS_INVALID.code, "should not be empty: " + key)
  144. if not os.getenv("AISTUDIO_ACCESS_TOKEN") and 'token' not in kwargs:
  145. return err_resp(ErrorEnum.TOKEN_IS_EMPTY.code, ErrorEnum.TOKEN_IS_EMPTY.message)
  146. if 'desc' in kwargs and not isinstance(kwargs['desc'], str):
  147. return err_resp(ErrorEnum.PARAMS_INVALID.code, ErrorEnum.PARAMS_INVALID.message)
  148. repo_name_raw = kwargs['repo_id']
  149. if "/" in repo_name_raw:
  150. user_name, repo_name = repo_name_raw.split('/')
  151. user_name = user_name.strip()
  152. repo_name = repo_name.strip()
  153. if not repo_name or not user_name:
  154. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  155. "user_name or repo_name is empty. repo_id should be user_name/repo_name format.")
  156. kwargs['repo_id'] = repo_name
  157. else:
  158. kwargs['repo_id'] = repo_name_raw.strip()
  159. # return err_resp(ErrorEnum.PARAMS_INVALID.code,
  160. # "r epo_id should be user_name/repo_name format.")
  161. if repo_type == 'model':
  162. more_params = {
  163. 'repoType': 0 if kwargs.get('private') else 1,
  164. 'repoName': kwargs['repo_id'],
  165. 'modelName': kwargs.get('model_name', ''), # 添加模型名
  166. 'desc': kwargs.get('desc', ''),
  167. 'license': kwargs.get('license', 'Apache License 2.0'),
  168. 'token': kwargs.get('token', '')
  169. }
  170. params.update(more_params)
  171. resp = convert_to_dict_object(request_aistudio_hub(**params))
  172. else:
  173. more_params = {
  174. 'repoType': 0 if kwargs.get('private') else 1,
  175. 'repoName': kwargs['repo_id'],
  176. 'appName': kwargs.get('app_name', ''),
  177. 'appType': kwargs.get('app_sdk', ''),
  178. 'desc': kwargs.get('desc', ''),
  179. 'license': kwargs.get('license', 'Apache License 2.0'),
  180. 'token': kwargs.get('token', '')
  181. }
  182. params.update(more_params)
  183. resp_raw = request_aistudio_app_service(**params)
  184. log.debug(f"create_repo resp: {resp_raw}")
  185. resp = convert_to_dict_object(resp_raw)
  186. log.debug(f"create_repo resp dict: {resp}")
  187. if 'errorCode' in resp and resp['errorCode'] != 0:
  188. log.error(f"create_repo failed: {resp}")
  189. if "repo already created" in resp['errorMsg']:
  190. res = err_resp(ErrorEnum.REPO_ALREADY_EXIST.code,
  191. resp['errorMsg'],
  192. resp['errorCode'],
  193. resp['logId']) # 错误logid透传
  194. else:
  195. res = err_resp(ErrorEnum.AISTUDIO_CREATE_REPO_FAILED.code,
  196. resp['errorMsg'],
  197. resp['errorCode'],
  198. resp['logId'])
  199. return res
  200. if repo_type == 'model':
  201. res = {
  202. 'model_name': resp['result']['modelName'],
  203. 'repo_id': resp['result']['repoName'],
  204. 'private': True if resp['result']['repoType'] == 0 else False,
  205. 'desc': resp['result']['desc'],
  206. 'license': resp['result']['license']
  207. }
  208. else:
  209. res = {
  210. 'app_id': resp['result']['appId'],
  211. 'app_name': resp['result']['appName'],
  212. 'repo_id': resp['result']['repoName'],
  213. 'desc': resp['result']['desc'],
  214. 'license': resp['result']['license']
  215. }
  216. return res
  217. def _upload_lfs_file(self, settings, file_path, file_size):
  218. """
  219. 上传文件
  220. settings: 上传文件的配置信息
  221. settings = {
  222. 'upload'[bool]: True or False
  223. 'upload_href'[str]: upload url
  224. 'sts_token'[dict]: sts token
  225. {
  226. "bos_host":"",
  227. "bucket_name": "",
  228. "key":"",
  229. "access_key_id": "",
  230. "secret_access_key": "",
  231. "session_token": "",
  232. "expiration": ""
  233. }
  234. }
  235. file_path: 本地文件路径
  236. """
  237. if not settings.get('upload'):
  238. logging.info("file already exists, skip the upload.")
  239. return True
  240. upload_href = settings['upload_href']
  241. sts_token = settings.get('sts_token', {})
  242. is_sts_valid = False
  243. if sts_token and sts_token.get("bos_host"):
  244. is_sts_valid = True
  245. is_http_valid = True if upload_href and file_size < config.LFS_FILE_SIZE_LIMIT_PUT else False
  246. def _uploading_using_sts():
  247. """
  248. 使用sts上传文件
  249. """
  250. from aistudio_sdk.utils.bos_sdk import sts_client, upload_file, upload_super_file
  251. try:
  252. client = sts_client(sts_token.get("bos_host"), sts_token.get("access_key_id"),
  253. sts_token.get("secret_access_key"), sts_token.get("session_token"))
  254. res = upload_super_file(client,
  255. bucket=sts_token.get("bucket_name"), file=file_path, key=sts_token.get("key"))
  256. return res
  257. except Exception as e:
  258. raise UploadFileException(e)
  259. def _uploading_using_http():
  260. """
  261. 使用http上传文件
  262. """
  263. try:
  264. res = request_bos_upload(upload_href, file_path)
  265. if 'error_code' in res and res['error_code'] != ErrorEnum.SUCCESS.code:
  266. return res
  267. return True
  268. except Exception as e:
  269. raise UploadFileException(e)
  270. functions = []
  271. if is_sts_valid:
  272. functions.append(_uploading_using_sts)
  273. if is_http_valid:
  274. functions.append(_uploading_using_http)
  275. if not os.environ.get("PERFER_STS_UPLOAD", default="true") == "true":
  276. functions.reverse()
  277. if not functions:
  278. logging.error("no upload method available.")
  279. return False
  280. upload_success = False
  281. for func in functions:
  282. try:
  283. logging.info(f"uploading file using {func.__name__}")
  284. res = func()
  285. if res is True:
  286. logging.info(f"upload lfs file success. {func.__name__}")
  287. upload_success = True
  288. break
  289. else:
  290. logging.error(f"upload lfs file failed. {func.__name__}: {res}")
  291. except UploadFileException as e:
  292. logging.error(f"upload lfs file failed. {func.__name__}: {e}")
  293. logging.debug(traceback.format_exc())
  294. return upload_success
  295. @staticmethod
  296. def _get_suffix_forbidden(repo_id):
  297. try:
  298. url = "{}{}".format(
  299. os.getenv("STUDIO_MODEL_API_URL_PREFIX", default=config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT),
  300. config.BLACK_LIST_URL
  301. )
  302. if repo_id:
  303. url = f"{url}?repoId={repo_id}"
  304. response = requests.get(url)
  305. if response.status_code == 200:
  306. r = response.json()
  307. if r['errorCode'] == 0:
  308. return r['result']
  309. else:
  310. return []
  311. except Exception as e:
  312. log.error(f"get black list fail:{e}")
  313. return []
  314. def file_exists(self, repo_id, filename, *args, **kwargs):
  315. """
  316. 文件是否存在
  317. params:
  318. repo_id: 仓库id,格式为user_name/repo_name
  319. filename: 仓库中的文件路径
  320. revision: 分支名
  321. token: 认证令牌
  322. """
  323. # 参数检查
  324. str_params_not_valid = 'params not valid.'
  325. kwargs['repo_id'] = repo_id
  326. kwargs['filename'] = filename
  327. # 检查入参值的格式类型
  328. for key in ['filename', 'repo_id', 'revision', 'token']:
  329. if key in kwargs:
  330. if type(kwargs[key]) != str:
  331. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  332. ErrorEnum.PARAMS_INVALID.message)
  333. kwargs[key] = kwargs[key].strip()
  334. if not kwargs[key]:
  335. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  336. ErrorEnum.PARAMS_INVALID.message)
  337. revision = kwargs['revision'] if kwargs.get('revision') else 'master'
  338. file_path = kwargs['filename']
  339. token = kwargs['token'] if 'token' in kwargs else ''
  340. repo_name = kwargs['repo_id']
  341. if "/" not in repo_name:
  342. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  343. ErrorEnum.PARAMS_INVALID.message)
  344. user_name, repo_name = repo_name.split('/')
  345. user_name = user_name.strip()
  346. repo_name = repo_name.strip()
  347. if not repo_name or not user_name:
  348. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  349. ErrorEnum.PARAMS_INVALID.message)
  350. call_host = os.getenv("STUDIO_GIT_HOST", default=config.STUDIO_GIT_HOST_DEFAULT)
  351. if not is_valid_host(call_host):
  352. return err_resp(ErrorEnum.PARAMS_INVALID.code,
  353. 'host not valid.')
  354. if os.environ.get("SKIP_REPO_VISIBLE_CHECK", default="false") != "true":
  355. # 检查仓库可见权限(他人的预发布仓库不能下载、查看)
  356. params = {
  357. 'repoId': kwargs['repo_id'],
  358. 'token': kwargs['token'] if 'token' in kwargs else ''
  359. }
  360. resp = convert_to_dict_object(request_aistudio_repo_visible(**params))
  361. if 'errorCode' in resp and resp['errorCode'] != 0:
  362. res = err_resp(ErrorEnum.AISTUDIO_NO_REPO_READ_AUTH.code,
  363. resp['errorMsg'],
  364. resp['errorCode'],
  365. resp['logId'])
  366. return res
  367. # 查询文件是否存在
  368. info_res = request_aistudio_git_file_info(call_host, user_name, repo_name, file_path,
  369. revision, token)
  370. if get_exist_file_old_sha(info_res) == '':
  371. return False
  372. else:
  373. return True
  374. def _prepare_upload_folder(
  375. self,
  376. folder_path_or_files: Union[str, Path, List[str], List[Path]],
  377. path_in_repo: str,
  378. allow_patterns: Optional[Union[List[str], str]] = None,
  379. ignore_patterns: Optional[Union[List[str], str]] = None,
  380. ):
  381. folder_path = None
  382. files_path = None
  383. if isinstance(folder_path_or_files, list):
  384. if os.path.isfile(folder_path_or_files[0]):
  385. files_path = folder_path_or_files
  386. else:
  387. raise ValueError('Uploading multiple folders is not supported now.')
  388. else:
  389. if os.path.isfile(folder_path_or_files):
  390. files_path = [folder_path_or_files]
  391. else:
  392. folder_path = folder_path_or_files
  393. if files_path is None:
  394. self.upload_checker.check_folder(folder_path)
  395. folder_path = Path(folder_path).expanduser().resolve()
  396. if not folder_path.is_dir():
  397. raise ValueError(f"Provided path: '{folder_path}' is not a directory")
  398. # List files from folder
  399. relpath_to_abspath = {
  400. path.relative_to(folder_path).as_posix(): path
  401. for path in sorted(folder_path.glob('**/*')) # sorted to be deterministic
  402. if path.is_file()
  403. }
  404. else:
  405. relpath_to_abspath = {}
  406. for path in files_path:
  407. if os.path.isfile(path):
  408. self.upload_checker.check_file(path)
  409. relpath_to_abspath[os.path.basename(path)] = path
  410. # Filter files
  411. filtered_repo_objects = list(
  412. UploadingCheck.filter_repo_objects(
  413. relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
  414. )
  415. )
  416. prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ''
  417. prepared_repo_objects = [
  418. (prefix + relpath, str(relpath_to_abspath[relpath]))
  419. for relpath in filtered_repo_objects
  420. ]
  421. return prepared_repo_objects
  422. def upload_file(
  423. self,
  424. *,
  425. path_or_fileobj: Union[str, Path, bytes, BinaryIO],
  426. path_in_repo: str,
  427. repo_id: str,
  428. token: Union[str, None] = None,
  429. repo_type: Optional[str] = config.REPO_TYPE_MODEL,
  430. commit_message: Optional[str] = None,
  431. revision: Optional[str] = config.DEFAULT_REPOSITORY_REVISION,
  432. ):
  433. """
  434. upload single file
  435. """
  436. if repo_type not in config.REPO_TYPE_SUPPORT:
  437. raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {config.REPO_TYPE_SUPPORT}')
  438. if not path_or_fileobj:
  439. raise ValueError('Path or file object cannot be empty!')
  440. if isinstance(path_or_fileobj, (str, Path)):
  441. path_or_fileobj = os.path.abspath(os.path.expanduser(path_or_fileobj))
  442. path_in_repo = path_in_repo or os.path.basename(path_or_fileobj)
  443. else:
  444. # If path_or_fileobj is bytes or BinaryIO, then path_in_repo must be provided
  445. if not path_in_repo:
  446. raise ValueError('Arg `path_in_repo` cannot be empty!')
  447. commit_message = (
  448. commit_message if commit_message is not None else f'Add {path_in_repo}'
  449. )
  450. # Read file content if path_or_fileobj is a file-like object (BinaryIO)
  451. if isinstance(path_or_fileobj, io.BufferedIOBase):
  452. path_or_fileobj = path_or_fileobj.read()
  453. self.upload_folder(repo_id=repo_id, folder_path=path_or_fileobj,
  454. path_in_repo=path_in_repo, token=token, repo_type=repo_type, commit_message=commit_message,
  455. revision=revision, single=True)
  456. def upload_folder(
  457. self,
  458. repo_id: str,
  459. folder_path: Union[str, Path, List[str], List[Path]] = None,
  460. path_in_repo: Optional[str] = '',
  461. commit_message: Optional[str] = None,
  462. token: Union[str, None] = None,
  463. repo_type: Optional[str] = config.REPO_TYPE_MODEL,
  464. allow_patterns: Optional[Union[List[str], str]] = None,
  465. ignore_patterns: Optional[Union[List[str], str]] = None,
  466. max_workers: int = config.DEFAULT_MAX_WORKERS,
  467. revision: Optional[str] = config.DEFAULT_REPOSITORY_REVISION,
  468. single: bool = False
  469. ):
  470. """upload"""
  471. if repo_type not in config.REPO_TYPE_SUPPORT:
  472. raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {config.REPO_TYPE_SUPPORT}')
  473. if token is None:
  474. token = os.getenv("AISTUDIO_ACCESS_TOKEN")
  475. allow_patterns = allow_patterns if allow_patterns else None
  476. ignore_patterns = ignore_patterns if ignore_patterns else None
  477. # Ignore .git folder
  478. if ignore_patterns is None:
  479. ignore_patterns = []
  480. elif isinstance(ignore_patterns, str):
  481. ignore_patterns = [ignore_patterns]
  482. commit_message = (
  483. commit_message if commit_message is not None else f'Upload folder to repo'
  484. )
  485. if single:
  486. prepared_repo_objects = [(path_in_repo, folder_path)]
  487. else:
  488. # Get the list of files to upload, e.g. [('data/abc.png', '/path/to/abc.png'), ...]
  489. prepared_repo_objects = self._prepare_upload_folder(
  490. folder_path_or_files=folder_path,
  491. path_in_repo=path_in_repo,
  492. allow_patterns=allow_patterns,
  493. ignore_patterns=ignore_patterns,
  494. )
  495. git_host = os.getenv("STUDIO_GIT_HOST", default=config.STUDIO_GIT_HOST_DEFAULT)
  496. user_name, repo_name = repo_id.split('/')
  497. user_name = user_name.strip()
  498. repo_name = repo_name.strip()
  499. if not repo_name or not user_name:
  500. raise ValueError("repo_name or user_name is empty,abort upload.")
  501. repo_path_list = []
  502. for name, _ in prepared_repo_objects:
  503. repo_path_list.append(name)
  504. if len(repo_path_list) == 0:
  505. return
  506. lfs_map = request_aistudio_git_files_type(git_host, user_name, repo_name,
  507. revision, repo_path_list, token)
  508. lfs_local_path_map = {}
  509. for remote_path, local_path in prepared_repo_objects:
  510. lfs_local_path_map[local_path] = lfs_map[remote_path]
  511. self.upload_checker.check_normal_files(
  512. file_path_list=[item for _, item in prepared_repo_objects],
  513. lfs_map=lfs_local_path_map
  514. )
  515. black_extensions = self._get_suffix_forbidden(repo_id)
  516. @thread_executor(max_workers=max_workers, disable_tqdm=False)
  517. def _upload_items(item_pair, log_list):
  518. file_path_in_repo, file_path = item_pair
  519. if is_readme_md(file_path=file_path) and file_path_in_repo == 'README.md' and revision == "master":
  520. try:
  521. url = "{}{}".format(
  522. os.getenv("STUDIO_MODEL_API_URL_PREFIX", default=config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT),
  523. config.README_CHECK_URL)
  524. yaml_content = extract_yaml_block(file_path)
  525. payload = {
  526. "yaml": yaml_content,
  527. "repoId": repo_id
  528. }
  529. headers = {
  530. "Content-Type": "application/json"
  531. }
  532. response = requests.post(url, json=payload, headers=headers, timeout=(10, 10))
  533. if response.status_code == 200:
  534. data = response.json()
  535. if data.get('errorCode') == 0:
  536. log.debug(f"调用成功,logId:{data.get('logId')}")
  537. else:
  538. error_msg = data.get("errorMsg")
  539. log.error(f"check readme fail:{error_msg},skip{file_path}")
  540. log_list.append((file_path, f"check readme fail{error_msg}"))
  541. return None
  542. except Exception as e:
  543. log.info(f"check readme fail:{e}")
  544. log_list.append((file_path, f"check readme fail:{e}"))
  545. return None
  546. suffix = Path(file_path).suffix.lower()
  547. if black_extensions and suffix in black_extensions:
  548. log.info(f"File:{file_path} forbidden! Skip.")
  549. log_list.append((file_path, "file type forbidden"))
  550. return None
  551. hash_info_d: dict = get_file_hash(
  552. file_path_or_obj=file_path,
  553. )
  554. file_size: int = hash_info_d['file_size']
  555. file_hash: str = hash_info_d['file_hash']
  556. return self._upload_and_gather_commit_info(
  557. repo_id=repo_id,
  558. sha256=file_hash,
  559. size=file_size,
  560. data=file_path,
  561. token=token,
  562. revision=revision,
  563. file_path_in_repo=file_path_in_repo,
  564. git_host=git_host,
  565. is_lfs=lfs_map.get(file_path_in_repo),
  566. log_list=log_list
  567. )
  568. skip_list = []
  569. uploaded_item_raw = _upload_items(
  570. prepared_repo_objects,
  571. log_list=skip_list
  572. )
  573. uploaded_item_list = [item for item in uploaded_item_raw if item is not None]
  574. if len(uploaded_item_list) == 0 or uploaded_item_list is None:
  575. log.error('nothing to commit')
  576. return
  577. commit_files(
  578. log_list=skip_list,
  579. repo_id=repo_id,
  580. revision=revision,
  581. commit_message=commit_message,
  582. file_quads=uploaded_item_list,
  583. token=token
  584. )
  585. if len(skip_list) > 0:
  586. print('these files were skipped with reasons:')
  587. for local_path, reason in skip_list:
  588. print(f"{local_path}: {reason}")
  589. def _upload_and_gather_commit_info(
  590. self,
  591. *,
  592. repo_id: str,
  593. sha256: str,
  594. size: int,
  595. data: str,
  596. token: str,
  597. revision: str,
  598. file_path_in_repo: str,
  599. git_host: str,
  600. is_lfs: bool,
  601. log_list
  602. ):
  603. if "/" not in repo_id:
  604. raise ValueError("repo_id should be user_name/repo_name format.")
  605. user_name, repo_name = repo_id.split('/')
  606. user_name = user_name.strip()
  607. repo_name = repo_name.strip()
  608. if not repo_name or not user_name:
  609. raise ValueError("repo_name or user_name is empty,abort upload.")
  610. if is_lfs:
  611. try:
  612. pre_res = request_aistudio_git_upload_access(git_host, user_name, repo_name, revision,
  613. size, sha256, token)
  614. except Exception as e:
  615. log.error(f"{data} request upload_access fail,skip,{e}")
  616. log_list.append((data, "request upload_access fail"))
  617. return None
  618. logging.debug(f"the request_aistudio_git_upload_access res: {pre_res}")
  619. if 'error_code' in pre_res and pre_res['error_code'] != ErrorEnum.SUCCESS.code:
  620. log.error(f"{data} upload fail due to request git upload error:{pre_res}")
  621. log_list.append((data, "upload fail due to request git upload error"))
  622. return None
  623. if not pre_res.get('upload'):
  624. log.info(f'file {data} with sha {sha256[:8]} has already uploaded.')
  625. return file_path_in_repo, data, is_lfs, sha256
  626. upload_res = self._upload_lfs_file(pre_res, data, size)
  627. if not upload_res:
  628. log.error(f"upload this lfs file {data} failed. 文件上传终止")
  629. log_list.append((data, "upload lfs file failed,server error "))
  630. return None
  631. if pre_res.get("verify_href"):
  632. verify_res = request_aistudio_verify_lfs_file(pre_res.get("verify_href"), sha256, size, token)
  633. logging.info(f"verify lfs file res: {verify_res}")
  634. if 'error_code' in verify_res and verify_res['error_code'] != ErrorEnum.SUCCESS.code:
  635. logging.error(f"verify lfs file failed:{data}.")
  636. log_list.append((data, "verify lfs file failed"))
  637. return None
  638. # 第五步:上传LFS指针文件(到仓库)
  639. # lfs_res = request_single_git_upload_pointer(git_host, user_name, repo_name, revision,
  640. # sha256, size, file_path_in_repo, token)
  641. return file_path_in_repo, data, is_lfs, sha256
  642. else:
  643. log.debug("Start uploading this common file.")
  644. # 如果大小超标,报错返回
  645. if size > config.COMMON_FILE_SIZE_LIMIT:
  646. log.error(f"File:{data} is larger than 5MB for a common file. Fail")
  647. log_list.append((data, "larger than 5MB for a common file"))
  648. return None
  649. return file_path_in_repo, data, is_lfs, sha256
  650. class UploadingCheck:
  651. """
  652. check class
  653. """
  654. def __init__(
  655. self,
  656. max_file_count: int = config.UPLOAD_MAX_FILE_COUNT,
  657. max_file_count_in_dir: int = config.UPLOAD_MAX_FILE_COUNT_IN_DIR,
  658. max_file_size: int = config.UPLOAD_MAX_FILE_SIZE,
  659. size_threshold_to_enforce_lfs: int = config.UPLOAD_SIZE_THRESHOLD_TO_ENFORCE_LFS,
  660. normal_file_size_total_limit: int = config.UPLOAD_NORMAL_FILE_SIZE_TOTAL_LIMIT,
  661. ):
  662. self.max_file_count = max_file_count
  663. self.max_file_count_in_dir = max_file_count_in_dir
  664. self.max_file_size = max_file_size
  665. self.size_threshold_to_enforce_lfs = size_threshold_to_enforce_lfs
  666. self.normal_file_size_total_limit = normal_file_size_total_limit
  667. def check_file(self, file_path_or_obj):
  668. """
  669. check size
  670. """
  671. if isinstance(file_path_or_obj, (str, Path)):
  672. if not os.path.exists(file_path_or_obj):
  673. raise ValueError(f'File {file_path_or_obj} does not exist')
  674. file_size: int = get_file_size(file_path_or_obj)
  675. if file_size > self.max_file_size:
  676. log.warn(f'File exceeds size limit: {self.max_file_size / (1024 ** 3)} GB, '
  677. f'got {round(file_size / (1024 ** 3), 4)} GB')
  678. def check_folder(self, folder_path: Union[str, Path]):
  679. """
  680. check
  681. """
  682. file_count = 0
  683. dir_count = 0
  684. if isinstance(folder_path, str):
  685. folder_path = Path(folder_path)
  686. for item in folder_path.iterdir():
  687. if item.is_file():
  688. file_count += 1
  689. item_size: int = get_file_size(item)
  690. if item_size > self.max_file_size:
  691. log.warn(f'File {item} exceeds size limit: {self.max_file_size / (1024 ** 3)} GB, '
  692. f'got {round(item_size / (1024 ** 3), 4)} GB')
  693. elif item.is_dir():
  694. dir_count += 1
  695. sub_file_count, sub_dir_count = self.check_folder(item)
  696. if (sub_file_count + sub_dir_count) > self.max_file_count_in_dir:
  697. raise ValueError(f'Directory {item} contains {sub_file_count + sub_dir_count} items '
  698. f'and exceeds limit: {self.max_file_count_in_dir}')
  699. file_count += sub_file_count
  700. dir_count += sub_dir_count
  701. if file_count > self.max_file_count:
  702. raise ValueError(f'Total file count {file_count} and exceeds limit: {self.max_file_count}')
  703. return file_count, dir_count
  704. def check_normal_files(self, file_path_list: List[Union[str, Path]], lfs_map: dict):
  705. """
  706. check
  707. """
  708. normal_file_list = [item for item in file_path_list if not lfs_map[item]]
  709. total_size = sum([get_file_size(item) for item in normal_file_list])
  710. if total_size > self.normal_file_size_total_limit:
  711. raise ValueError(f'Total size of non-lfs files {total_size / (1024 * 1024)}MB '
  712. f'and exceeds limit: {self.normal_file_size_total_limit / (1024 * 1024)}MB')
  @staticmethod
  def filter_repo_objects(
      items: Iterable[T],
      *,
      allow_patterns: Optional[Union[List[str], str]] = None,
      ignore_patterns: Optional[Union[List[str], str]] = None,
      key: Optional[Callable[[T], str]] = None,
  ):
      """Yield the items whose path passes an allowlist and a denylist.

      ``items`` can be either of two kinds:

      * an iterable of paths (``str`` or ``Path``);
      * an iterable of arbitrary objects together with ``key``, a one-argument
        callable that returns the path of an object.

      Patterns are Unix shell-style wildcards matched with ``fnmatch.fnmatch``.
      They are NOT regular expressions; see
      https://docs.python.org/3/library/fnmatch.html for the syntax.

      Because ``fnmatch`` normalizes case with ``os.path.normcase``, matching
      is case-insensitive on Windows and case-sensitive elsewhere. A pattern
      ending in ``/`` gets a trailing ``*`` so it matches what lies under that
      directory.

      Args:
          items (`Iterable`):
              Items to filter.
          allow_patterns (`str` or `List[str]`, *optional*):
              Allowlist. When given and non-empty, a path is kept only if it
              matches at least one of these patterns.
          ignore_patterns (`str` or `List[str]`, *optional*):
              Denylist. When given and non-empty, a path is dropped if it
              matches any of these patterns.
          key (`Callable[[T], str]`, *optional*):
              Extracts the path from an item. Without it, every item must
              already be a `str` or `Path`.

      Returns:
          A generator over the retained items, in input order.

      Raises:
          :class:`ValueError`:
              During iteration, if `key` is not provided and an item is neither
              a `str` nor a `Path`.

      Example usage with paths:
      ```python
      >>> # Keep only PDFs that are not hidden.
      >>> list(UploadingCheck.filter_repo_objects(
      ...     ["aaa.pdf", "bbb.jpg", ".ccc.pdf", ".ddd.png"],
      ...     allow_patterns=["*.pdf"],
      ...     ignore_patterns=[".*"],
      ... ))
      ['aaa.pdf']
      ```
      """
      def _prepare(patterns):
          # Empty values mean "no filter"; a lone string becomes a one-item list.
          if not patterns:
              return None
          if isinstance(patterns, str):
              patterns = [patterns]
          return [UploadingCheck._add_wildcard_to_directories(p) for p in patterns]

      allow_list = _prepare(allow_patterns)
      deny_list = _prepare(ignore_patterns)

      def _path_of(obj: T):
          # Default extractor: items must already be str or Path.
          if isinstance(obj, str):
              return obj
          if isinstance(obj, Path):
              return str(obj)
          raise ValueError(
              f'Please provide `key` argument in `filter_repo_objects`: `{obj}` is not a string.'
          )

      extract = _path_of if key is None else key
      for candidate in items:
          candidate_path = extract(candidate)
          is_allowed = allow_list is None or any(
              fnmatch(candidate_path, pat) for pat in allow_list)
          if not is_allowed:
              continue
          is_denied = deny_list is not None and any(
              fnmatch(candidate_path, pat) for pat in deny_list)
          if is_denied:
              continue
          yield candidate
  @staticmethod
  def _add_wildcard_to_directories(pattern: str):
      """Turn a directory pattern into one that matches everything inside it.

      A pattern ending in ``/`` (e.g. ``"data/"``) gets a trailing ``*``
      (``"data/*"``). Since ``fnmatch``'s ``*`` also matches ``/``, the result
      matches every path under that directory. Any other pattern, including the
      empty string, is returned unchanged.

      Args:
          pattern: A Unix shell-style wildcard pattern.

      Returns:
          The possibly extended pattern.
      """
      # endswith() instead of pattern[-1]: indexing an empty pattern raised
      # IndexError (reachable via e.g. allow_patterns=[""]).
      if pattern.endswith('/'):
          return pattern + '*'
      return pattern
  def create_repo(**kwargs):
      """Create a repository.

      Convenience wrapper: builds a ``Hub`` client and forwards all keyword
      arguments unchanged to ``Hub.create_repo``.

      Returns:
          Whatever ``Hub.create_repo`` returns.
      """
      client = Hub()
      return client.create_repo(**kwargs)
  def upload(**kwargs):
      """Unsupported upload entry point.

      Ignores all keyword arguments, logs an error pointing callers to
      ``upload_file``, and returns ``None``.
      """
      log.error("This function is not supported.Please use upload_file instead.")
  def upload_file(*,
                  path_or_fileobj: Union[str, Path, bytes, BinaryIO],
                  path_in_repo: str,
                  repo_id: str,
                  token: Union[str, None] = None,
                  repo_type: Optional[str] = config.REPO_TYPE_MODEL,
                  commit_message: Optional[str] = None,
                  revision: Optional[str] = config.DEFAULT_REPOSITORY_REVISION,):
      """Upload a single file to a repository.

      Convenience wrapper: builds a ``Hub`` client and forwards every argument,
      by keyword and unchanged, to ``Hub.upload_file``. See that method for the
      meaning of each parameter.

      Returns:
          Whatever ``Hub.upload_file`` returns.
      """
      forwarded = dict(
          path_or_fileobj=path_or_fileobj,
          path_in_repo=path_in_repo,
          repo_id=repo_id,
          token=token,
          repo_type=repo_type,
          commit_message=commit_message,
          revision=revision,
      )
      client = Hub()
      return client.upload_file(**forwarded)
  def upload_folder(*,
                    repo_id: str,
                    folder_path: Union[str, Path, List[str], List[Path]] = None,
                    path_in_repo: Optional[str] = '',
                    commit_message: Optional[str] = None,
                    token: Union[str, None] = None,
                    repo_type: Optional[str] = config.REPO_TYPE_MODEL,
                    allow_patterns: Optional[Union[List[str], str]] = None,
                    ignore_patterns: Optional[Union[List[str], str]] = None,
                    max_workers: int = config.DEFAULT_MAX_WORKERS,
                    revision: Optional[str] = config.DEFAULT_REPOSITORY_REVISION,):
      """Upload a folder to a repository.

      Convenience wrapper: builds a ``Hub`` client and forwards every argument
      unchanged to ``Hub.upload_folder``. See that method for the meaning of
      each parameter.

      Returns:
          Whatever ``Hub.upload_folder`` returns.
      """
      # Forward by keyword (as upload_file does) rather than by position:
      # positional forwarding silently mis-assigns values if the parameter
      # order of Hub.upload_folder ever differs from this wrapper's, and fails
      # outright if that method is keyword-only.
      # NOTE(review): assumes Hub.upload_folder uses these parameter names — confirm.
      return Hub().upload_folder(
          repo_id=repo_id,
          folder_path=folder_path,
          path_in_repo=path_in_repo,
          commit_message=commit_message,
          token=token,
          repo_type=repo_type,
          allow_patterns=allow_patterns,
          ignore_patterns=ignore_patterns,
          max_workers=max_workers,
          revision=revision,
      )
  def file_exists(repo_id, filename, *args, **kwargs):
      """Check whether a file exists in the remote repository.

      Convenience wrapper: builds a ``Hub`` client and forwards ``repo_id``,
      ``filename`` and any extra positional/keyword arguments unchanged to
      ``Hub.file_exists``.

      Returns:
          Whatever ``Hub.file_exists`` returns.
      """
      client = Hub()
      return client.file_exists(repo_id, filename, *args, **kwargs)