util.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2023 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. 本文件实现了常用的工具函数
  10. Authors: xiangyiqing(xiangyiqing@baidu.com)
  11. Date: 2023/07/24
  12. """
  13. import tempfile
  14. import sys
  15. import os
  16. import io
  17. import re
  18. import base64
  19. import hashlib
  20. from datetime import datetime, timezone, timedelta
  21. import zipfile
  22. from aistudio_sdk import log
  23. from aistudio_sdk.errors import FileIntegrityError
  24. from aistudio_sdk.config import DEFAULT_MAX_WORKERS
  25. from functools import wraps
  26. from tqdm.auto import tqdm
  27. from concurrent.futures import ThreadPoolExecutor, as_completed
  28. from typing import List, Union, BinaryIO, Optional
  29. from aistudio_sdk.constant.version import VERSION
  30. from pathlib import Path
  31. class Dict(dict):
  32. """dict class"""
  33. def __getattr__(self, key):
  34. value = self.get(key, None)
  35. return Dict(value) if isinstance(value, dict) else value
  36. def __setattr__(self, key, value):
  37. self[key] = value
  38. def convert_to_dict_object(resp):
  39. """
  40. Params
  41. :resp: dict, response from AIStudio
  42. Rerurns
  43. AIStudio object
  44. """
  45. if isinstance(resp, dict):
  46. return Dict(resp)
  47. return resp
  48. def err_resp(sdk_code, msg, biz_code=None, log_id=None):
  49. """
  50. 构造错误响应信息。
  51. Params:
  52. sdk_code (str): SDK错误码,标识错误类型。
  53. msg (str): 错误描述信息。
  54. biz_code (str, optional): 业务层面的错误码,透传自上游接口。
  55. log_id (str, optional): 与错误相关的日志ID,透传自上游接口。
  56. Returns:
  57. dict: 格式化好的错误信息。
  58. """
  59. return {
  60. "error_code": sdk_code, # 错误码
  61. "error_msg": msg, # 错误消息
  62. "biz_code": biz_code, # 业务错误码
  63. "log_id": log_id # 日志ID
  64. }
  65. def is_valid_host(host):
  66. """检测host合法性"""
  67. # 去除可能的协议前缀 如http://、https://
  68. host = re.sub(r'^https?://', '', host, flags=re.IGNORECASE)
  69. result = is_valid_domain(host)
  70. # if not result:
  71. # host = re.sub(r'^http?://', '', host, flags=re.IGNORECASE)
  72. # result = is_valid_domain(host)
  73. return result
  74. def is_valid_domain(domain):
  75. """检测域名合法性"""
  76. return True
  77. # pattern = r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)(\.[A-Za-z]{2,})+$"
  78. # return re.match(pattern, domain) is not None
  79. def calculate_sha256(file_path):
  80. """将文件计算为sha256值"""
  81. sha256_hash = hashlib.sha256()
  82. with open(file_path, "rb") as file:
  83. # 逐块更新哈希值,以适应大型文件
  84. while True:
  85. data = file.read(65536) # 64K块大小
  86. if not data:
  87. break
  88. sha256_hash.update(data)
  89. return sha256_hash.hexdigest()
  90. def gen_ISO_format_datestr():
  91. """
  92. # 生成 ISO 8601日期时间格式
  93. # 例如"2023-09-12T11:29:45.703Z"
  94. """
  95. # 获取当前日期和时间
  96. zone = timezone(timedelta(hours=8))
  97. now = datetime.now(zone)
  98. # 使用strftime函数将日期和时间格式化为所需的字符串格式
  99. formatted_date = now.isoformat(timespec='milliseconds')
  100. return formatted_date
  101. def gen_MD5(file_path):
  102. """将文件计算为md5值"""
  103. md5_hash = hashlib.md5()
  104. try:
  105. with open(file_path, 'rb') as file:
  106. # 逐块读取文件并更新哈希对象
  107. while True:
  108. data = file.read(4096) # 读取4K字节数据块
  109. if not data:
  110. break
  111. md5_hash.update(data)
  112. except FileNotFoundError:
  113. print(f"The file '{file_path}' does not exist.")
  114. return None
  115. # 获取MD5哈希值的十六进制表示
  116. md5_hex = md5_hash.hexdigest()
  117. return md5_hex
  118. def gen_base64(original_string):
  119. """将字符串计算为base64"""
  120. # 将原始字符串编码为字节数组
  121. bytes_data = original_string.encode('utf-8')
  122. # 使用base64进行编码
  123. base64_encoded = base64.b64encode(bytes_data).decode('utf-8')
  124. return base64_encoded
  125. def create_sha256_file_and_encode_base64(sha256, size):
  126. """生成指定内容的文件并进行base64编码字符串返回"""
  127. content = f"version https://git-lfs.github.com/spec/v1\noid sha256:{sha256}\nsize {size}"
  128. with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.txt') as tmp:
  129. tmp.write(content)
  130. tmp_path = tmp.name
  131. log.debug(tmp_path)
  132. try:
  133. with open(tmp_path, 'rb') as f:
  134. encoded = base64.b64encode(f.read()).decode('utf-8')
  135. return encoded
  136. finally:
  137. if os.path.exists(tmp_path):
  138. os.remove(tmp_path)
  139. # name = 'sha256_value'
  140. # with open(name, 'w') as file:
  141. # file.write(content)
  142. #
  143. # ret = file_to_base64(name)
  144. # os.remove(name)
  145. # return ret
  146. def file_to_base64(filename):
  147. """读取文件内容并进行Base64编码"""
  148. with open(filename, "rb") as file:
  149. contents = file.read()
  150. encoded_contents = base64.b64encode(contents)
  151. return encoded_contents.decode('utf-8')
  152. def zip_dir(dirpath, out_full_name):
  153. """
  154. 压缩指定文件夹
  155. :param dirpath: 目标文件夹路径
  156. :param out_full_name: 压缩文件保存路径 xxxx.zip
  157. :return: 无
  158. """
  159. zip_obj = zipfile.ZipFile(out_full_name, "w", zipfile.ZIP_DEFLATED)
  160. for path, dirnames, filenames in os.walk(dirpath):
  161. # 去掉目标跟路径,只对目标文件夹下边的文件及文件夹进行压缩
  162. fpath = path.replace(dirpath, '')
  163. for filename in filenames:
  164. zip_obj.write(os.path.join(path, filename), os.path.join(fpath, filename))
  165. zip_obj.close()
  166. def compute_hash(file_path):
  167. """
  168. hash
  169. """
  170. BUFFER_SIZE = 1024 * 64 # 64k buffer size
  171. sha256_hash = hashlib.sha256()
  172. with open(file_path, 'rb') as f:
  173. while True:
  174. data = f.read(BUFFER_SIZE)
  175. if not data:
  176. break
  177. sha256_hash.update(data)
  178. return sha256_hash.hexdigest()
  179. def file_integrity_validation(file_path, expected_sha256):
  180. """Validate the file hash is expected, if not, delete the file
  181. Args:
  182. file_path (str): The file to validate
  183. expected_sha256 (str): The expected sha256 hash
  184. Raises:
  185. FileIntegrityError: If file_path hash is not expected.
  186. """
  187. file_sha256 = compute_hash(file_path)
  188. if not file_sha256 == expected_sha256:
  189. os.remove(file_path)
  190. msg = ('File %s integrity check failed, expected sha256 signature is %s, '
  191. 'actual is %s, the download may be incomplete, please try again.') % ( # noqa E501
  192. file_path, expected_sha256, file_sha256)
  193. log.error(msg)
  194. raise FileIntegrityError(msg)
  195. def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
  196. disable_tqdm: bool = False,
  197. tqdm_desc: str = None):
  198. """
  199. A decorator to execute a function in a threaded manner using ThreadPoolExecutor.
  200. Args:
  201. max_workers (int): The maximum number of threads to use.
  202. disable_tqdm (bool): disable progress bar.
  203. tqdm_desc (str): Desc of tqdm.
  204. Returns:
  205. function: A wrapped function that executes with threading and a progress bar.
  206. """
  207. def decorator(func):
  208. @wraps(func)
  209. def wrapper(iterable, *args, **kwargs):
  210. results = []
  211. # Create a tqdm progress bar with the total number of items to process
  212. with tqdm(
  213. unit_scale=True,
  214. unit_divisor=1024,
  215. initial=0,
  216. total=len(iterable),
  217. desc=tqdm_desc or f'Processing {len(iterable)} items',
  218. disable=disable_tqdm,
  219. ) as pbar:
  220. # Define a wrapper function to update the progress bar
  221. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  222. # Submit all tasks
  223. futures = {
  224. executor.submit(func, item, *args, **kwargs): item
  225. for item in iterable
  226. }
  227. # Update the progress bar as tasks complete
  228. for future in as_completed(futures):
  229. pbar.update(1)
  230. results.append(future.result())
  231. return results
  232. return wrapper
  233. return decorator
  234. def get_model_masked_directory(directory, model_id):
  235. """
  236. 目录
  237. """
  238. if sys.platform.startswith('win'):
  239. parts = directory.rsplit('\\', 2)
  240. else:
  241. parts = directory.rsplit('/', 2)
  242. # this is the actual directory the model files are located.
  243. masked_directory = os.path.join(parts[0], model_id.replace('.', '___'))
  244. return masked_directory
  245. def convert_patterns(raw_input: Union[str, List[str]]):
  246. """
  247. 处理规则
  248. """
  249. output = None
  250. if isinstance(raw_input, str):
  251. output = list()
  252. if ',' in raw_input:
  253. output = [s.strip() for s in raw_input.split(',')]
  254. else:
  255. output.append(raw_input.strip())
  256. elif isinstance(raw_input, list):
  257. output = list()
  258. for s in raw_input:
  259. if isinstance(s, str):
  260. if ',' in s:
  261. output.extend([ss.strip() for ss in s.split(',')])
  262. else:
  263. output.append(s.strip())
  264. return output
  265. def header_fill(params=None, token=''):
  266. """
  267. 填充header
  268. """
  269. if token:
  270. auth = f'token {token}'
  271. else:
  272. auth = f'token {os.getenv("AISTUDIO_ACCESS_TOKEN", default="")}'
  273. headers = {
  274. 'Content-Type': 'application/json',
  275. 'Authorization': auth,
  276. 'SDK-Version': str(VERSION)
  277. }
  278. if params:
  279. headers.update(params)
  280. return headers
  281. def extract_yaml_block(file_path):
  282. """
  283. 获取yaml
  284. """
  285. with open(file_path, 'r', encoding='utf-8') as f:
  286. content = f.read()
  287. # 提取 --- 和 --- 之间的内容(非贪婪匹配)
  288. match = re.search(r'^---\s*(.*?)\s*---', content, re.DOTALL | re.MULTILINE)
  289. if match:
  290. return match.group(1).strip()
  291. else:
  292. raise ValueError("未找到两个 '---' 分隔的 YAML 内容")
  293. def is_readme_md(file_path):
  294. """
  295. 判断文件名
  296. """
  297. file_name = os.path.basename(file_path)
  298. return file_name == 'README.md'
  299. def get_file_size(file_path_or_obj: Union[str, Path, bytes, BinaryIO]) -> int:
  300. """
  301. get size
  302. """
  303. if isinstance(file_path_or_obj, (str, Path)):
  304. file_path = Path(file_path_or_obj)
  305. return file_path.stat().st_size
  306. elif isinstance(file_path_or_obj, bytes):
  307. return len(file_path_or_obj)
  308. elif isinstance(file_path_or_obj, io.BufferedIOBase):
  309. current_position = file_path_or_obj.tell()
  310. file_path_or_obj.seek(0, os.SEEK_END)
  311. size = file_path_or_obj.tell()
  312. file_path_or_obj.seek(current_position)
  313. return size
  314. else:
  315. raise TypeError(
  316. 'Unsupported type: must be string, Path, bytes, or io.BufferedIOBase'
  317. )
  318. def get_file_hash(
  319. file_path_or_obj: Union[str, Path, bytes, BinaryIO],
  320. buffer_size_mb: Optional[int] = 1,
  321. tqdm_desc: Optional[str] = '[Calculating]',
  322. disable_tqdm: Optional[bool] = True,
  323. ) -> dict:
  324. """
  325. calculate hash
  326. """
  327. from tqdm.auto import tqdm
  328. file_size = get_file_size(file_path_or_obj)
  329. if file_size > 1024 * 1024 * 1024: # 1GB
  330. disable_tqdm = False
  331. name = 'Large File'
  332. if isinstance(file_path_or_obj, (str, Path)):
  333. path = file_path_or_obj if isinstance(
  334. file_path_or_obj, Path) else Path(file_path_or_obj)
  335. name = path.name
  336. tqdm_desc = f'[Validating Hash for {name}]'
  337. buffer_size = buffer_size_mb * 1024 * 1024
  338. file_hash = hashlib.sha256()
  339. chunk_hash_list = []
  340. progress = tqdm(
  341. total=file_size,
  342. initial=0,
  343. unit_scale=True,
  344. dynamic_ncols=True,
  345. unit='B',
  346. desc=tqdm_desc,
  347. disable=disable_tqdm,
  348. )
  349. if isinstance(file_path_or_obj, (str, Path)):
  350. with open(file_path_or_obj, 'rb') as f:
  351. while byte_chunk := f.read(buffer_size):
  352. chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
  353. file_hash.update(byte_chunk)
  354. progress.update(len(byte_chunk))
  355. file_hash = file_hash.hexdigest()
  356. final_chunk_size = buffer_size
  357. elif isinstance(file_path_or_obj, bytes):
  358. file_hash.update(file_path_or_obj)
  359. file_hash = file_hash.hexdigest()
  360. chunk_hash_list.append(file_hash)
  361. final_chunk_size = len(file_path_or_obj)
  362. progress.update(final_chunk_size)
  363. elif isinstance(file_path_or_obj, io.BufferedIOBase):
  364. while byte_chunk := file_path_or_obj.read(buffer_size):
  365. chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
  366. file_hash.update(byte_chunk)
  367. progress.update(len(byte_chunk))
  368. file_hash = file_hash.hexdigest()
  369. final_chunk_size = buffer_size
  370. else:
  371. progress.close()
  372. raise ValueError(
  373. 'Input must be str, Path, bytes or a io.BufferedIOBase')
  374. progress.close()
  375. return {
  376. 'file_path_or_obj': file_path_or_obj,
  377. 'file_hash': file_hash,
  378. 'file_size': file_size,
  379. 'chunk_size': final_chunk_size,
  380. 'chunk_nums': len(chunk_hash_list),
  381. 'chunk_hash_list': chunk_hash_list,
  382. }