dataset.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
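Dataset requests: download dataset files, request BOS upload credentials, register uploaded files, create datasets and add files to them.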
  10. Authors: suoyi@baidu.com
  11. Date: 2024/7/20
  12. """
  13. import json
  14. import requests
  15. from aistudio_sdk import config, log
  16. from baidubce.bce_client_configuration import BceClientConfiguration
  17. from baidubce.auth.bce_credentials import BceCredentials
  18. from baidubce.services.bos.bos_client import BosClient
  19. import os
  20. import threading
  21. from concurrent.futures import ThreadPoolExecutor
  22. from tqdm import tqdm
  23. from urllib.parse import urljoin
  24. from pathlib import Path
  25. import re
  26. class RequestDatasetException(Exception):
  27. """
  28. exception for requesting dataset server
  29. """
  30. pass
  31. MAX_WORKERS_FILE = os.path.expanduser("~/.download_max_workers")
  32. # 默认线程数
  33. DEFAULT_MAX_WORKERS = 6
  34. def get_max_workers():
  35. """max download worker"""
  36. try:
  37. with open(MAX_WORKERS_FILE, 'r') as f:
  38. return int(f.read().strip())
  39. except (Exception) as e:
  40. return DEFAULT_MAX_WORKERS
  41. def post_request_get_file_ids(url, datasetId):
  42. """file info"""
  43. data = {"datasetId": datasetId}
  44. response = requests.post(url, data=data)
  45. response.raise_for_status()
  46. result = response.json().get("result", {})
  47. file_ids = result.get("fileIds", [])
  48. return file_ids
  49. def load_token():
  50. """
  51. load
  52. """
  53. if not os.path.exists(config.TOKEN_FILE):
  54. return None
  55. with open(config.TOKEN_FILE, 'r') as f:
  56. return f.read().strip()
  57. def _header_fill(params=None, token=''):
  58. """
  59. 填充header
  60. """
  61. if token:
  62. auth = f'{token}'
  63. else:
  64. auth = f'{os.getenv("AISTUDIO_ACCESS_TOKEN", default="")}'
  65. headers = {
  66. 'Content-Type': 'application/json',
  67. 'Authorization': auth
  68. }
  69. if params:
  70. headers.update(params)
  71. return headers
  72. def get_file_url(host, datasetId, fileId):
  73. """get url"""
  74. path = f"/llm/files/datasets/{datasetId}/file/{fileId}/download"
  75. url = urljoin(host, path)
  76. token = load_token()
  77. print(token)
  78. if token is not None:
  79. headers = _header_fill(token=token)
  80. else:
  81. headers = _header_fill()
  82. response = requests.get(url, headers=headers)
  83. response.raise_for_status()
  84. print(response.json())
  85. return response.json()["result"]["fileUrl"]
  86. CHUNK_SIZE = 160 * 1024 * 1024 # 160MB
  87. def parse_filename_from_cd(cd_header):
  88. """filename"""
  89. if not cd_header:
  90. return None
  91. fname = re.findall('filename="?([^";]+)"?', cd_header)
  92. return fname[0] if fname else None
  93. def get_file_info(file_url):
  94. """获取文件大小和文件名"""
  95. r = requests.head(file_url, allow_redirects=True)
  96. r.raise_for_status()
  97. file_size = int(r.headers.get('Content-Length', 0))
  98. cd = r.headers.get("Content-Disposition", "")
  99. filename = parse_filename_from_cd(cd)
  100. if not filename:
  101. filename = os.path.basename(file_url.split("?")[0])
  102. return file_size, filename
  103. def download_chunk(file_url, start, end, local_path, pbar, lock):
  104. """download"""
  105. headers = {'Range': f"bytes={start}-{end}"}
  106. response = requests.get(file_url, headers=headers, stream=True)
  107. response.raise_for_status()
  108. with open(local_path, 'rb+') as f:
  109. f.seek(start)
  110. for chunk in response.iter_content(chunk_size=8192):
  111. if chunk:
  112. f.write(chunk)
  113. with lock:
  114. pbar.update(len(chunk))
  115. def download_file_multithreaded(file_url, local_dir, max_workers=None):
  116. """multi thread"""
  117. if max_workers is None:
  118. max_workers = get_max_workers()
  119. # Step 1: Get file size and filename
  120. file_size, filename = get_file_info(file_url)
  121. local_path = os.path.join(local_dir, filename)
  122. os.makedirs(local_dir, exist_ok=True)
  123. # Step 2: Create empty file if not exists
  124. if not os.path.exists(local_path):
  125. with open(local_path, 'wb') as f:
  126. f.truncate(file_size)
  127. # Step 3: Calculate chunks
  128. chunks = []
  129. for i in range(0, file_size, CHUNK_SIZE):
  130. start = i
  131. end = min(i + CHUNK_SIZE - 1, file_size - 1)
  132. chunks.append((start, end))
  133. # Step 4: Prepare download tasks
  134. from threading import Lock
  135. pbar = tqdm(total=file_size, unit='B', unit_scale=True, desc=filename)
  136. lock = Lock()
  137. with ThreadPoolExecutor(max_workers=max_workers) as executor:
  138. futures = []
  139. for start, end in chunks:
  140. # 检查是否已下载该块
  141. if os.path.exists(local_path):
  142. current_size = os.path.getsize(local_path)
  143. if current_size >= end + 1:
  144. pbar.update(end - start + 1)
  145. continue
  146. futures.append(
  147. executor.submit(download_chunk, file_url, start, end, local_path, pbar, lock)
  148. )
  149. for f in futures:
  150. f.result()
  151. pbar.close()
  152. def download_datasets(datasetId, local_dir=None):
  153. """old dataset"""
  154. if local_dir is None:
  155. local_dir = os.getenv("HOME")
  156. host = os.getenv("STUDIO_GIT_HOST", default=config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT)
  157. url = f"{host}/studio/dataset/detail"
  158. download_all_files(url, host, datasetId, local_dir)
  159. def download_all_files(url, host, datasetId, localDir):
  160. """
  161. all
  162. """
  163. file_ids = post_request_get_file_ids(url, datasetId)
  164. os.makedirs(localDir, exist_ok=True)
  165. tasks = []
  166. pbar_lock = threading.Lock()
  167. with ThreadPoolExecutor(max_workers=4) as executor:
  168. for fileId in file_ids:
  169. file_url = get_file_url(host, datasetId, fileId)
  170. tasks.append(executor.submit(download_file_multithreaded, file_url, localDir, pbar_lock))
  171. for task in tasks:
  172. task.result()
  173. def bos_acl_dataset_file(
  174. token: str,
  175. bucket_name=None
  176. ):
  177. """
  178. 申请ak/sk
  179. response:
  180. {
  181. "logId": "",
  182. "errorCode": 0,
  183. "errorMsg": "",
  184. "timestamp": 0,
  185. "result": {
  186. "accessKeyId": "",
  187. "secretAccessKey": "",
  188. "sessionToken": "",
  189. "fileKey": "",
  190. "serverTime": 0,
  191. "expiresIn": 0,
  192. "endpoint": "",
  193. "bucketName": ""
  194. }
  195. }
  196. """
  197. url = f"{config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT}/llm/files/acl"
  198. headers = {
  199. "Authorization": f"{token}",
  200. "Content-Type": "application/json"
  201. }
  202. params = {}
  203. if bucket_name:
  204. params["bucketName"] = bucket_name
  205. response = requests.get(url, headers=headers, params=params)
  206. if response.status_code == 200:
  207. return response.json()
  208. else:
  209. raise RequestDatasetException(f"Failed to get bos acl: {response.text}")
  210. def add_file_with_retry(token: str, file_origin_name: str, file_key: str, bucket_name=None, file_abs=None):
  211. """
  212. 上传文件到指定的bucket,并返回文件ID。
  213. """
  214. for i in range(3):
  215. try:
  216. file_id = add_file(token, file_origin_name, file_key, bucket_name, file_abs)
  217. return file_id
  218. except RequestDatasetException as e:
  219. log.error(f"add file 失败,重试第{i+1}次")
  220. log.error(e)
  221. def add_file(token: str, file_origin_name: str, file_key: str, bucket_name=None, file_abs=None):
  222. """
  223. 上传文件到指定的bucket,并返回文件ID。
  224. Args:
  225. token (str): 认证token。
  226. file_origin_name (str): 文件的原始名称。
  227. file_key (str): 文件在存储中的键值。
  228. bucket_name (str, optional): 如果提供,则上传到此bucket,否则使用默认bucket。
  229. file_abs (str, optional): 文件的绝对路径,可选。
  230. Returns:
  231. dict: 包含操作结果的字典,其中包括logId, errorCode, errorMsg, timestamp和result(包含fileId)。
  232. Raises:
  233. HTTPError: 如果请求失败,抛出异常。
  234. """
  235. log.debug("add file..")
  236. url = f"{config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT}/llm/files/addfile"
  237. headers = {
  238. "Authorization": f"{token}",
  239. "Content-Type": "application/json"
  240. }
  241. data = {
  242. "fileOriginName": file_origin_name,
  243. "fileKey": file_key,
  244. "bucketName": bucket_name,
  245. "fileAbs": file_abs
  246. }
  247. response = requests.post(url, headers=headers, json=data)
  248. if response.status_code == 200:
  249. if response.json().get("errorCode") == 0:
  250. log.debug(f"add file success")
  251. result = response.json()
  252. file_id = result.get("result", {}).get("fileId")
  253. return file_id
  254. else:
  255. log.error("落库失败")
  256. log.error(f"add file failed, response: {data} {response.text}")
  257. return None
  258. else:
  259. raise RequestDatasetException(f"Failed to add file: {response.text}")
  260. def create_dataset_with_retry(token: str, dataset_name: str, file_ids: list,
  261. dataset_type=1, dataset_abs="", dataset_license=1):
  262. """
  263. 创建一个新的数据集,并返回数据集ID。
  264. """
  265. for i in range(3):
  266. try:
  267. dataset_id = create_dataset(token, dataset_name, file_ids, dataset_type, dataset_abs, dataset_license)
  268. return dataset_id
  269. except RequestDatasetException as e:
  270. log.error(f"create dataset 失败,重试第{i+1}次")
  271. log.error(e)
  272. def create_dataset(token: str, dataset_name: str, file_ids: list, dataset_type=1, dataset_abs="", dataset_license=1):
  273. """
  274. 创建一个新的数据集,并返回数据集ID。
  275. Args:
  276. token (str): 认证token。
  277. dataset_name (str): 数据集的名称。
  278. file_ids (list of int): 包含在数据集中的文件ID列表。
  279. dataset_type (int, optional): 数据集的类型,1 表示私有,2 表示公开。默认为0(私有)。
  280. dataset_abs (str, optional): 数据集的简介,可选。
  281. Returns:
  282. dict: 包含操作结果的字典,其中包括logId, errorCode, errorMsg, timestamp和result(包含datasetId)。
  283. None: 如果请求失败,返回None。
  284. """
  285. url = f"{config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT}/llm/files/datasets"
  286. headers = {
  287. "Authorization": f"{token}",
  288. "Content-Type": "application/json"
  289. }
  290. data = {
  291. "datasetName": dataset_name,
  292. "datasetAbs": dataset_abs,
  293. "fileIds": file_ids,
  294. "datasetType": dataset_type,
  295. "protocolId": dataset_license
  296. }
  297. response = requests.post(url, headers=headers, json=data)
  298. if response.status_code == 200:
  299. log.debug(f"add file success")
  300. if response.json().get("errorCode") == 0:
  301. result = response.json()
  302. dataset_id = result.get("result", {}).get("datasetId")
  303. return dataset_id
  304. else:
  305. log.error(f"数据集创建失败:{response.json().get('errorMsg')}")
  306. log.debug(f"add file failed, response: {data} {response.text}")
  307. return None
  308. else:
  309. raise RequestDatasetException(f"Failed to create dataset: {response.text}")
  310. def add_files_to_dataset_with_retry(token: str, dataset_id: int, file_ids: list):
  311. """
  312. 向指定的数据集中添加文件。
  313. """
  314. for i in range(3):
  315. try:
  316. result = add_files_to_dataset(token, dataset_id, file_ids)
  317. return result
  318. except RequestDatasetException as e:
  319. log.error(f"add file to dataset 失败,重试第{i+1}次")
  320. log.error(e)
  321. def add_files_to_dataset(token: str, dataset_id: int, file_ids: list):
  322. """
  323. 向指定的数据集中添加文件。
  324. Args:
  325. token (str): 认证token。
  326. dataset_id (int): 数据集的ID。
  327. file_ids (list of int): 需要添加到数据集的文件ID列表。
  328. Returns:
  329. dict: 包含操作结果的字典,其中包括logId, errorCode, errorMsg, timestamp和result。
  330. None: 如果请求失败,返回None。
  331. """
  332. url = f"{config.STUDIO_MODEL_API_URL_PREFIX_DEFAULT}/llm/files/datasets/{dataset_id}/addfile"
  333. headers = {
  334. "Authorization": f"{token}",
  335. "Content-Type": "application/json"
  336. }
  337. data = {
  338. "fileIds": file_ids
  339. }
  340. response = requests.post(url, headers=headers, json=data)
  341. if response.status_code == 200:
  342. if response.json().get("errorCode") == 0:
  343. log.info(f"向数据集[{dataset_id}]中添加文件成功!")
  344. log.debug(f"向数据集[{dataset_id}]中添加文件成功[{file_ids}]")
  345. return response.json()
  346. else:
  347. log.error(f"添加文件失败: {response.json().get('errorMsg')}")
  348. log.debug(f"add file failed, response: {data} {response.text}")
  349. return None
  350. else:
  351. # log.error(f"Failed to add files to dataset: {response.text}")
  352. raise RequestDatasetException(f"Failed to add files to dataset: {response.text}")