hub.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2023 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. 本文件实现了请求hub
  10. Authors: xiangyiqing(xiangyiqing@baidu.com)
  11. Date: 2023/07/24
  12. """
  13. import base64
  14. import os
  15. import time
  16. import tqdm
  17. import json
  18. import requests
  19. from urllib.parse import quote
  20. from aistudio_sdk.constant.err_code import ErrorEnum
  21. from aistudio_sdk import config, log
  22. from aistudio_sdk.utils.util import err_resp
  23. from aistudio_sdk.utils.util import gen_ISO_format_datestr
  24. from aistudio_sdk.utils.util import file_to_base64, thread_executor
  25. from aistudio_sdk.utils.util import create_sha256_file_and_encode_base64
  26. from aistudio_sdk.constant.version import VERSION
  27. from aistudio_sdk.errors import RequestError
  28. CONNECTION_RETRY_TIMES = config.CONNECTION_RETRY_TIMES
  29. CONNECTION_TIMEOUT = config.CONNECTION_TIMEOUT
  30. CONNECTION_TIMEOUT_DOWNLOAD = config.CONNECTION_TIMEOUT_DOWNLOAD
  31. CONNECTION_TIMEOUT_UPLOAD = config.CONNECTION_TIMEOUT_UPLOAD
  32. #################### AIStudio 云端模型库 API ####################
  33. def _request_aistudio_hub(method, url, headers, data):
  34. """
  35. request aistudio hub
  36. """
  37. for _ in range(CONNECTION_RETRY_TIMES):
  38. try:
  39. response = requests.request(method, url, headers=headers,
  40. data=data, timeout=CONNECTION_TIMEOUT)
  41. return response.json()
  42. except requests.exceptions.JSONDecodeError:
  43. err_msg = "Response body does not contain valid json: {}".format(response.text)
  44. biz_code = response.status_code
  45. log.debug(err_msg)
  46. return err_resp(ErrorEnum.INTERNAL_ERROR.code,
  47. err_msg[:500],
  48. biz_code)
  49. def request_aistudio_hub(**kwargs):
  50. """
  51. 请求AIStudio hub: 模型库
  52. """
  53. headers = _header_fill(token=kwargs['token'])
  54. kwargs.pop('token')
  55. url = "{}{}".format(
  56. os.getenv("STUDIO_MODEL_API_URL_PREFIX", default=config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT),
  57. config.HUB_URL
  58. )
  59. body = {k: v for k, v in kwargs.items()}
  60. log.debug(f"request_aistudio_hub url: {url}")
  61. log.debug(f"request_aistudio_hub body: {body}")
  62. payload = json.dumps(body)
  63. resp = _request_aistudio_hub('POST', url, headers, payload)
  64. return resp
  65. def request_aistudio_app_service(**kwargs):
  66. """
  67. 请求AIStudio hub:应用库
  68. """
  69. headers = _header_fill(token=kwargs['token'])
  70. kwargs.pop('token')
  71. url = "{}{}".format(
  72. os.getenv("STUDIO_MODEL_API_URL_PREFIX", default=config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT),
  73. config.APP_SERVICE_URL
  74. )
  75. body = {k: v for k, v in kwargs.items()}
  76. log.debug(f"request_aistudio_app_service url: {url}")
  77. log.debug(f"request_aistudio_app_service body: {body}")
  78. # log.debug(f"request_aistudio_app_service headers: {headers}")
  79. payload = json.dumps(body)
  80. resp = _request_aistudio_hub('POST', url, headers, payload)
  81. return resp
  82. def request_aistudio_repo_visible(**kwargs):
  83. """
  84. 请求AIStudio hub 查看repo可见权限
  85. """
  86. headers = _header_fill(token=kwargs['token'])
  87. url = "{}{}".format(
  88. os.getenv("STUDIO_MODEL_API_URL_PREFIX", default=config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT),
  89. config.HUB_URL_VISIBLE_CHECK
  90. )
  91. url = url + f"?repoId={quote(kwargs['repoId'], safe='')}&authorization=1"
  92. method = 'GET'
  93. try:
  94. err_msg = ''
  95. response = requests.request(method, url, headers=headers,
  96. timeout=CONNECTION_TIMEOUT)
  97. return response.json()
  98. except requests.exceptions.JSONDecodeError:
  99. err_msg = "Response body does not contain valid json: {}".format(response)
  100. biz_code = response.status_code
  101. return err_resp(ErrorEnum.INTERNAL_ERROR.code,
  102. err_msg[:500],
  103. biz_code)
  104. #################### AIStudio Gitea API ####################
  105. def _request_gitea(method, url, headers, data):
  106. """
  107. request gitea
  108. """
  109. for _ in range(CONNECTION_RETRY_TIMES):
  110. session = requests.Session()
  111. response = session.request(method, url, headers=headers, data=data, timeout=CONNECTION_TIMEOUT)
  112. session.close()
  113. if response.status_code not in (200, 201):
  114. log.debug(f"response: {response.text} {response.status_code}")
  115. log.info(f"potential network problem while request{url}:{response}")
  116. extra_msg = "[仓库或分支不存在]" if response.status_code == 404 else ""
  117. return err_resp(ErrorEnum.GITEA_DOWNLOAD_FILE_FAILED.code if
  118. method == "GET" else ErrorEnum.GITEA_UPLOAD_FILE_FAILED.code,
  119. response.content.decode()[:500] + extra_msg,
  120. biz_code=response.status_code)
  121. else:
  122. return response.json()
  123. def timing_decorator(func):
  124. """
  125. time cost decorator
  126. """
  127. def wrapper(*args, **kwargs):
  128. start_time = time.time()
  129. result = func(*args, **kwargs)
  130. end_time = time.time()
  131. elapsed_time = end_time - start_time
  132. print(f"{func.__name__} done, time cost: {elapsed_time:.2f}s")
  133. return result
  134. return wrapper
  135. @timing_decorator
  136. def _upload(method, url, headers, data):
  137. """
  138. _upload proc
  139. """
  140. session = requests.Session()
  141. response = session.request(method, url, headers=headers, data=data,
  142. stream=True, timeout=CONNECTION_TIMEOUT_UPLOAD)
  143. session.close()
  144. if response.status_code not in (200, 201):
  145. return err_resp(ErrorEnum.GITEA_UPLOAD_FILE_FAILED.code,
  146. response.content[:500],
  147. biz_code=response.status_code)
  148. else:
  149. return response.json()
  150. @timing_decorator
  151. def _download(url, download_path, headers):
  152. """
  153. Params
  154. :url: http url
  155. :download_path: download path
  156. :headers: headers
  157. Returns
  158. file
  159. """
  160. # 默认allow_redirects=True,即自动重定向,如果是LFS文件会直接从BOS下载
  161. response = requests.request('GET', url, stream=True, headers=headers,
  162. timeout=CONNECTION_TIMEOUT_DOWNLOAD)
  163. if response.status_code == 200:
  164. ret = {}
  165. elif response.status_code == 404:
  166. try:
  167. message = response.json()["message"]
  168. except requests.exceptions.JSONDecodeError:
  169. message = response.content.decode()
  170. ret = err_resp(ErrorEnum.FILE_NOT_FOUND.code,
  171. message,
  172. response.status_code)
  173. else:
  174. ret = err_resp(ErrorEnum.GITEA_DOWNLOAD_FILE_FAILED.code,
  175. f'Download failed, response code: {response.status_code}',
  176. response.status_code)
  177. total_size = int(response.headers.get('content-length', 0))
  178. block_size = 1024 * 100
  179. progress_bar = tqdm.tqdm(total=total_size, ncols=50, unit='iB', unit_scale=True,
  180. desc='Downloading file')
  181. with open(download_path, 'wb') as file:
  182. for data in response.iter_content(block_size):
  183. progress_bar.update(len(data))
  184. file.write(data)
  185. progress_bar.close()
  186. if total_size != 0 and progress_bar.n != total_size:
  187. print("ERROR, something went wrong")
  188. return ret
  189. def request_aistudio_git_download(url, download_path, token):
  190. """
  191. 请求AIStudio gitea文件下载
  192. """
  193. headers = _header_fill(token=token)
  194. res = _download(url, download_path, headers)
  195. return res
  196. def request_aistudio_git_file_info(call_host, user_name, repo_name, file_path,
  197. revision, token):
  198. """
  199. 请求AIStudio gitea 文件info
  200. GET /api/v1/repos/{owner}/{repo}/contents/{filepath} 返回的文件的数据、大小、编码等metadata信息+文件内容,或者文件夹中的文件列表
  201. """
  202. # 构建查询url
  203. url = f"{call_host}/api/v1/repos/{quote(user_name, safe='')}/" \
  204. f"{quote(repo_name, safe='')}/contents/{quote(file_path, safe='')}"
  205. if revision != 'master':
  206. url += f"?ref={quote(revision, safe='')}"
  207. headers = _header_fill(token=token)
  208. res = _request_gitea('GET', url, headers, "")
  209. log.debug(f"...result of GET /contents/{file_path}: {res}")
  210. return res
  211. def request_aistudio_git_file_type(call_host, user_name, repo_name, revision,
  212. path_in_repo, token):
  213. """
  214. 请求AIStudio gitea 确认文件类型
  215. """
  216. headers = _header_fill(token=token)
  217. url = f"{call_host}/{quote(user_name, safe='')}/{quote(repo_name, safe='')}/preupload/{quote(revision, safe='')}"
  218. params = {
  219. "files": [{
  220. "path": path_in_repo # 远程文件路径(相对于仓库根路径)
  221. }]
  222. }
  223. payload = json.dumps(params)
  224. result = _request_gitea('POST', url, headers, data=payload)
  225. log.debug(f"...result of POST /preupload: {url}: {result}")
  226. if 'error_code' in result:
  227. res = result
  228. elif 'files' not in result or not result['files'] or 'lfs' not in result['files'][0]:
  229. res = err_resp(ErrorEnum.GITEA_FAILED.code,
  230. str(result)[:500])
  231. else:
  232. res = {
  233. 'is_lfs': result['files'][0]['lfs']
  234. }
  235. return res
  236. def request_aistudio_git_files_type(call_host, user_name, repo_name, revision,
  237. path_in_repo_list, token):
  238. """
  239. 批量请求 AIStudio gitea 确认多个文件是否为 LFS 类型
  240. :param path_list: List[str],每个元素是 repo 内的相对路径
  241. :return: Dict[str, bool],key 是路径,value 是是否为 LFS
  242. """
  243. headers = _header_fill(token=token)
  244. url = (f"{call_host}/api/v1/repos/"
  245. f"{quote(user_name, safe='')}/{quote(repo_name, safe='')}/preupload/{quote(revision, safe='')}")
  246. params = {
  247. "files": [{"path": path} for path in path_in_repo_list]
  248. }
  249. payload = json.dumps(params)
  250. result = _request_gitea('POST', url, headers, data=payload)
  251. log.debug(f"...result of POST /preupload: {url}: {result}")
  252. if 'error_code' in result:
  253. raise ValueError(f"preupload fail, there is error_code in result")
  254. elif 'files' not in result or not isinstance(result['files'], list):
  255. raise ValueError(f"preupload fail, wrong result format")
  256. lfs_map = {}
  257. for file_info in result['files']:
  258. path = file_info.get('path')
  259. is_lfs = file_info.get('lfs', False)
  260. if path:
  261. lfs_map[path] = is_lfs
  262. return lfs_map
  263. def _parse_sts_token(upload_section: dict) -> dict:
  264. """
  265. 解析sts_token
  266. "upload": {
  267. "href": "https://some-download.com",
  268. "header": {
  269. "Key": "value"
  270. },
  271. "sts_token": {
  272. "bosHost":""
  273. "bucketName": "",
  274. "key":"",
  275. "accessKeyId":"",
  276. "secretAccessKey":"",
  277. "sessionToken":"",
  278. "createTime":"",
  279. "expiration":""
  280. }
  281. "expires_at": "2016-11-10T15:29:07Z"
  282. }
  283. """
  284. sts_token = upload_section.get("sts_token", {})
  285. if sts_token and sts_token.get("accessKeyId"):
  286. return {
  287. "bos_host": sts_token.get("bosHost"),
  288. "bucket_name": sts_token.get("bucketName"),
  289. "key": sts_token.get("key"),
  290. "access_key_id": sts_token.get("accessKeyId"),
  291. "secret_access_key": sts_token.get("secretAccessKey"),
  292. "session_token": sts_token.get("sessionToken"),
  293. "expiration": sts_token.get("expiration")
  294. }
  295. return {}
  296. def request_aistudio_git_upload_access(call_host, user_name, repo_name, revision, file_size,
  297. sha256, token):
  298. """
  299. 请求AIStudio gitea 申请上传LFS文件.
  300. 只支持单文件
  301. """
  302. params = {
  303. 'Content-Type': 'application/vnd.git-lfs+json; charset=utf-8',
  304. 'Accept': 'application/vnd.git-lfs+json'
  305. }
  306. headers = _header_fill(params=params, token=token)
  307. url = f"{call_host}/{quote(user_name, safe='')}/{quote(repo_name, safe='')}.git/info/lfs/objects/batch"
  308. params = {
  309. "operation": "upload", # 申请动作为上传
  310. "objects": [
  311. {
  312. "oid": sha256, # SHA256哈希
  313. "size": file_size # 单位byte
  314. }
  315. ],
  316. "transfers": [
  317. "lfs-standalone-file", "basic"
  318. ],
  319. "ref": {
  320. "name": f"refs/heads/{revision}" # 分支
  321. },
  322. "hash_algo": "sha256"
  323. }
  324. payload = json.dumps(params)
  325. result = _request_gitea('POST', url, headers, payload)
  326. log.debug(f"...result of POST /batch: {result}")
  327. if 'error_code' in result:
  328. res = result
  329. elif 'objects' not in result or not result['objects']:
  330. res = err_resp(ErrorEnum.GITEA_FAILED.code,
  331. str(result)[:500])
  332. else:
  333. tmp = result['objects'][0]
  334. # 已经存在的文件,不需要上传,actions为空
  335. res = {
  336. 'upload': True if 'actions' in tmp and 'upload' in tmp['actions'] else False,
  337. 'upload_href': tmp['actions']['upload']['href'] if 'actions' in tmp else '',
  338. 'sts_token': _parse_sts_token(tmp['actions']['upload']) if 'actions' in tmp else {},
  339. 'verify_href': tmp['actions']['verify']['href'] if 'actions' in tmp else ''
  340. }
  341. return res
  342. @timing_decorator
  343. def _lfs_upload(url, path_or_fileobj, headers):
  344. """
  345. 上传LFS文件到bos
  346. """
  347. with open(path_or_fileobj, 'rb') as file:
  348. response = requests.request('PUT', url, headers=headers, data=file,
  349. timeout=CONNECTION_TIMEOUT_UPLOAD, stream=True)
  350. return {'Content-Md5': response.headers['Content-Md5']}
  351. def request_bos_upload(url, path_or_fileobj):
  352. """
  353. 上传LFS文件到bos
  354. """
  355. params = {'Content-Type': 'application/octet-stream'}
  356. headers = _header_fill(params=params, token='')
  357. return _lfs_upload(url, path_or_fileobj, headers)
  358. def get_exist_file_old_sha(info_res):
  359. """
  360. 解析info_res
  361. """
  362. if 'error_code' in info_res and info_res['error_code'] != ErrorEnum.SUCCESS.code:
  363. return ''
  364. elif not info_res or 'sha' not in info_res:
  365. return ''
  366. else:
  367. old_sha = info_res['sha']
  368. return old_sha
  369. def request_aistudio_git_upload_pointer(call_host, user_name, repo_name, revision, commit_message,
  370. sha256, file_size, path_in_repo, token):
  371. """
  372. 请求AIStudio gitea 上传LFS指针文件(到仓库)
  373. """
  374. # 检查指针文件是否已存在,存在的话,要调用更新接口
  375. info_res = request_aistudio_git_file_info(call_host, user_name, repo_name, path_in_repo,
  376. revision, token)
  377. old_sha = get_exist_file_old_sha(info_res)
  378. if old_sha == '':
  379. method = 'POST'
  380. else:
  381. # 文件已存在,需要调用PUT接口更新
  382. method = 'PUT'
  383. headers = _header_fill(token=token)
  384. url = f"{call_host}/api/v1/repos/{quote(user_name, safe='')}/" \
  385. f"{quote(repo_name, safe='')}/contents/{quote(path_in_repo, safe='')}"
  386. params = {
  387. "branch": revision, # 提交的分支
  388. "new_branch": revision, # 提交的分支
  389. "content": create_sha256_file_and_encode_base64(sha256, file_size),
  390. "lfsPointer": True,
  391. "dates": {
  392. "author": gen_ISO_format_datestr(),
  393. "committer": gen_ISO_format_datestr()
  394. },
  395. "message": commit_message
  396. }
  397. if method == 'PUT':
  398. params['sha'] = old_sha
  399. payload = json.dumps(params)
  400. res = _request_gitea(method, url, headers, payload)
  401. return res
  402. def request_single_git_upload_pointer(call_host, user_name, repo_name, revision,
  403. sha256, file_size, path_in_repo, token):
  404. """
  405. 请求AIStudio gitea 上传LFS指针文件(到仓库)
  406. """
  407. # 检查指针文件是否已存在,存在的话,要调用更新接口
  408. info_res = request_aistudio_git_file_info(call_host, user_name, repo_name, path_in_repo,
  409. revision, token)
  410. old_sha = get_exist_file_old_sha(info_res)
  411. if old_sha == '':
  412. method = 'POST'
  413. else:
  414. # 文件已存在,需要调用PUT接口更新
  415. method = 'PUT'
  416. headers = _header_fill(token=token)
  417. url = f"{call_host}/api/v1/repos/{quote(user_name, safe='')}/" \
  418. f"{quote(repo_name, safe='')}/contents/{quote(path_in_repo, safe='')}"
  419. params = {
  420. "branch": revision, # 提交的分支
  421. "new_branch": revision, # 提交的分支
  422. "content": create_sha256_file_and_encode_base64(sha256, file_size),
  423. "lfsPointer": True,
  424. # "dates": {
  425. # "author": gen_ISO_format_datestr(),
  426. # "committer": gen_ISO_format_datestr()
  427. # },
  428. # "message": commit_message
  429. }
  430. if method == 'PUT':
  431. params['sha'] = old_sha
  432. payload = json.dumps(params)
  433. res = _request_gitea(method, url, headers, payload)
  434. return res
  435. def request_aistudio_git_upload_common(call_host, user_name, repo_name, revision,
  436. commit_message,
  437. path_or_fileobj, path_in_repo, token):
  438. """
  439. 请求AIStudio gitea 上传普通文件(到仓库)
  440. """
  441. # 检查文件是否已存在,存在的话,要调用更新接口
  442. info_res = request_aistudio_git_file_info(call_host, user_name, repo_name, path_in_repo,
  443. revision, token)
  444. old_sha = get_exist_file_old_sha(info_res)
  445. if old_sha == '':
  446. method = 'POST'
  447. else:
  448. # 文件已存在,需要调用PUT接口更新
  449. method = 'PUT'
  450. url = f"{call_host}/api/v1/repos/{quote(user_name, safe='')}/" \
  451. f"{quote(repo_name, safe='')}/contents/{quote(path_in_repo, safe='')}"
  452. headers = _header_fill(token=token)
  453. base64_data = file_to_base64(path_or_fileobj)
  454. params = {
  455. "branch": revision, # 提交的分支
  456. "new_branch": revision, # 提交的分支
  457. "content": base64_data,
  458. "lfs": False,
  459. "dates": {
  460. "author": gen_ISO_format_datestr(),
  461. "committer": gen_ISO_format_datestr()
  462. },
  463. "message": commit_message
  464. }
  465. if method == 'PUT':
  466. params['sha'] = old_sha
  467. payload = json.dumps(params)
  468. res = _upload(method, url, headers, payload)
  469. return res
  470. def request_single_git_upload_common(call_host, user_name, repo_name, revision,
  471. path_or_fileobj, path_in_repo, token):
  472. """
  473. 请求AIStudio gitea 上传普通文件(到仓库)
  474. """
  475. # 检查文件是否已存在,存在的话,要调用更新接口
  476. info_res = request_aistudio_git_file_info(call_host, user_name, repo_name, path_in_repo,
  477. revision, token)
  478. old_sha = get_exist_file_old_sha(info_res)
  479. if old_sha == '':
  480. method = 'POST'
  481. else:
  482. # 文件已存在,需要调用PUT接口更新
  483. method = 'PUT'
  484. url = f"{call_host}/api/v1/repos/{quote(user_name, safe='')}/" \
  485. f"{quote(repo_name, safe='')}/contents/{quote(path_in_repo, safe='')}"
  486. headers = _header_fill(token=token)
  487. base64_data = file_to_base64(path_or_fileobj)
  488. params = {
  489. # "branch": revision, # 提交的分支
  490. # "new_branch": revision, # 提交的分支
  491. "content": base64_data,
  492. "lfs": False,
  493. # "dates": {
  494. # "author": gen_ISO_format_datestr(),
  495. # "committer": gen_ISO_format_datestr()
  496. # },
  497. # "message": commit_message
  498. }
  499. if method == 'PUT':
  500. params['sha'] = old_sha
  501. payload = json.dumps(params)
  502. res = _upload(method, url, headers, payload)
  503. return res
  504. def request_aistudio_verify_lfs_file(call_host, oid: str, size: int, token=''):
  505. """
  506. param
  507. call_host: verify url
  508. oid: sha256, without sha256prefix
  509. size: file size
  510. """
  511. headers = {
  512. 'Content-Type': 'application/vnd.git-lfs+json',
  513. 'Accept': 'application/vnd.git-lfs+json'
  514. }
  515. params = {
  516. "oid": oid,
  517. "size": size
  518. }
  519. header = _header_fill(headers, token=token)
  520. res = requests.request("POST", call_host, headers=header, json=params, data=json.dumps(params))
  521. log.debug(f"...result of POST /verify: {res.text}")
  522. if res.status_code not in (200, 201):
  523. return err_resp(ErrorEnum.GITEA_UPLOAD_FILE_FAILED.code,
  524. res.text,
  525. biz_code=res.status_code)
  526. else:
  527. return res.json()
  528. def _header_fill(params=None, token=''):
  529. """
  530. 填充header
  531. """
  532. if token:
  533. auth = f'token {token}'
  534. else:
  535. auth = f'token {os.getenv("AISTUDIO_ACCESS_TOKEN", default="")}'
  536. headers = {
  537. 'Content-Type': 'application/json',
  538. 'Authorization': auth,
  539. 'SDK-Version': str(VERSION)
  540. }
  541. if params:
  542. headers.update(params)
  543. return headers
  544. def file_exists_and_sha(repo_id, revision, path, token):
  545. """检查目标文件是否存在,返回 (exists: bool, sha: str or None)"""
  546. host = os.getenv("STUDIO_GIT_HOST", default=config.STUDIO_GIT_HOST_DEFAULT)
  547. url = f"{host}/api/v1/repos/{repo_id}/contents/{path}?ref={revision}"
  548. headers = _header_fill(token=token)
  549. resp = requests.get(url, headers=headers)
  550. if resp.status_code == 200:
  551. return True, resp.json().get("sha")
  552. return False, None
  553. def get_lfs_pointer_content(sha256, size):
  554. """构造 LFS pointer 文件内容(纯文本)"""
  555. return f"version https://git-lfs.github.com/spec/v1\n" \
  556. f"oid sha256:{sha256}\nsize {size}\n"
  557. @thread_executor(disable_tqdm=False, max_workers=os.cpu_count())
  558. def prepare_entry(q, repo_id, revision, token):
  559. """
  560. prepare body data
  561. """
  562. path, local_path, is_lfs, sha256_input = q
  563. try:
  564. exists, sha = file_exists_and_sha(repo_id, revision, path, token)
  565. except Exception as e:
  566. log.error(f"{path} request git error,skip")
  567. return None
  568. if is_lfs:
  569. size = os.path.getsize(local_path)
  570. content_b64 = create_sha256_file_and_encode_base64(sha256_input or sha256_input, size)
  571. else:
  572. with open(local_path, "rb") as f:
  573. content_bytes = f.read()
  574. try:
  575. content_str = base64.b64encode(content_bytes)
  576. content_b64 = content_str.decode("utf-8")
  577. except UnicodeDecodeError:
  578. raise ValueError(f"❌ Non-UTF8 content in {local_path}. Consider LFS or base64 with encoding field.")
  579. entry = {
  580. "lfsPointer": is_lfs,
  581. "path": path,
  582. "content": content_b64,
  583. "operation": "update" if exists else "create"
  584. }
  585. if sha:
  586. entry["sha"] = sha
  587. return entry
  588. MAX_PAYLOAD_MB = 200
  589. MAX_PAYLOAD_BYTES = MAX_PAYLOAD_MB * 1024 * 1024
  590. def split_files_by_size(files, max_bytes=MAX_PAYLOAD_BYTES):
  591. """
  592. 将 files 拆分为多个分组,每组最大总大小不超过 max_bytes,每组个数不超过1k
  593. 要求每个 file 是一个 dict,包含 'size' 字段(单位字节)。
  594. """
  595. chunks = []
  596. current_chunk = []
  597. current_size = 0
  598. current_count = 0
  599. for file in files:
  600. file_size = file.get("size", 0)
  601. if file_size > max_bytes:
  602. raise ValueError(f"单个文件超过最大限制:{file_size / 1024 / 1024:.2f}MB")
  603. if (current_size + file_size > max_bytes or current_count + 1 >
  604. int(os.environ.get("MAX_COMMIT_FILE_COUNT", "500"))):
  605. chunks.append(current_chunk)
  606. current_chunk = [file]
  607. current_size = file_size
  608. current_count = 0
  609. else:
  610. current_chunk.append(file)
  611. current_size += file_size
  612. current_count += 1
  613. if current_chunk:
  614. chunks.append(current_chunk)
  615. return chunks
  616. def commit_files(log_list, repo_id, revision, commit_message, file_quads, token, author=None, committer=None):
  617. """
  618. commit
  619. """
  620. if author is None:
  621. author = {"name": "Auto Commit", "email": "auto@example.com"}
  622. if committer is None:
  623. committer = author
  624. log.info("calculate files")
  625. log.debug(f"cpu:{os.cpu_count()}")
  626. files = prepare_entry(file_quads, repo_id=repo_id, revision=revision, token=token)
  627. files = [item for item in files if item is not None]
  628. file_chunks = split_files_by_size(files)
  629. if len(file_chunks) > 1:
  630. log.info("files will be commited in multi batches")
  631. host = os.getenv("STUDIO_GIT_HOST", default=config.STUDIO_GIT_HOST_DEFAULT)
  632. headers = _header_fill(token=token)
  633. url = f"{host}/api/v1/repos/{repo_id}/contents"
  634. for i, chunk in enumerate(file_chunks, start=1):
  635. commit_message_current = commit_message if len(file_chunks) == 1 \
  636. else f"{commit_message} (part {i}/{len(file_chunks)})"
  637. chunk_payload = {
  638. "branch": revision,
  639. "message": commit_message_current,
  640. "author": author,
  641. "committer": committer,
  642. "files": chunk
  643. }
  644. resp = requests.post(url, headers=headers, json=chunk_payload)
  645. if resp.status_code // 100 == 2:
  646. print(resp.status_code)
  647. print(f"✅ Commit part {i} successful!")
  648. else:
  649. for entry in chunk:
  650. for path, local_path, is_lfs, sha256_input in file_quads:
  651. if entry['path'] == path:
  652. log_list.append((local_path, "commit fail"))
  653. print(f"❌ Commit part {i} failed: {resp.status_code}")
  654. print(resp.text)