file_download.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import copy
  3. import hashlib
  4. import io
  5. import os
  6. import shutil
  7. import tempfile
  8. import urllib
  9. import uuid
  10. from concurrent.futures import ThreadPoolExecutor
  11. from functools import partial
  12. from http.cookiejar import CookieJar
  13. from pathlib import Path
  14. from typing import Dict, List, Optional, Type, Union
  15. import requests
  16. from requests.adapters import Retry
  17. from tqdm.auto import tqdm
  18. from modelscope.hub.api import HubApi, ModelScopeConfig
  19. from modelscope.hub.constants import (
  20. API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES,
  21. API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
  22. MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME)
  23. from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
  24. DEFAULT_MODEL_REVISION,
  25. INTRA_CLOUD_ACCELERATION,
  26. REPO_TYPE_DATASET, REPO_TYPE_MODEL,
  27. REPO_TYPE_SUPPORT)
  28. from modelscope.utils.file_utils import (get_dataset_cache_root,
  29. get_model_cache_root)
  30. from modelscope.utils.logger import get_logger
  31. from .callback import ProgressCallback, TqdmCallback
  32. from .errors import FileDownloadError, InvalidParameter, NotExistError
  33. from .utils.caching import ModelFileSystemCache
  34. from .utils.utils import (file_integrity_validation, get_endpoint,
  35. model_id_to_group_owner_name)
  36. logger = get_logger()
  37. def model_file_download(
  38. model_id: str,
  39. file_path: str,
  40. revision: Optional[str] = DEFAULT_MODEL_REVISION,
  41. cache_dir: Optional[str] = None,
  42. user_agent: Union[Dict, str, None] = None,
  43. local_files_only: Optional[bool] = False,
  44. cookies: Optional[CookieJar] = None,
  45. local_dir: Optional[str] = None,
  46. ) -> Optional[str]: # pragma: no cover
  47. """Download from a given URL and cache it if it's not already present in the local cache.
  48. Given a URL, this function looks for the corresponding file in the local
  49. cache. If it's not there, download it. Then return the path to the cached
  50. file.
  51. Args:
  52. model_id (str): The model to whom the file to be downloaded belongs.
  53. file_path(str): Path of the file to be downloaded, relative to the root of model repo.
  54. revision(str, optional): revision of the model file to be downloaded.
  55. Can be any of a branch, tag or commit hash.
  56. cache_dir (str, Path, optional): Path to the folder where cached files are stored.
  57. user_agent (dict, str, optional): The user-agent info in the form of a dictionary or a string.
  58. local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
  59. local cached file if it exists. if `False`, download the file anyway even it exists.
  60. cookies (CookieJar, optional): The cookie of download request.
  61. local_dir (str, optional): Specific local directory path to which the file will be downloaded.
  62. Returns:
  63. string: string of local file or if networking is off, last version of
  64. file cached on disk.
  65. Raises:
  66. NotExistError: The file is not exist.
  67. ValueError: The request parameter error.
  68. Note:
  69. Raises the following errors:
  70. - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
  71. if `use_auth_token=True` and the token cannot be found.
  72. - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
  73. if ETag cannot be determined.
  74. - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  75. if some parameter value is invalid
  76. """
  77. return _repo_file_download(
  78. model_id,
  79. file_path,
  80. repo_type=REPO_TYPE_MODEL,
  81. revision=revision,
  82. cache_dir=cache_dir,
  83. user_agent=user_agent,
  84. local_files_only=local_files_only,
  85. cookies=cookies,
  86. local_dir=local_dir)
  87. def dataset_file_download(
  88. dataset_id: str,
  89. file_path: str,
  90. revision: Optional[str] = DEFAULT_DATASET_REVISION,
  91. cache_dir: Union[str, Path, None] = None,
  92. local_dir: Optional[str] = None,
  93. user_agent: Optional[Union[Dict, str]] = None,
  94. local_files_only: Optional[bool] = False,
  95. cookies: Optional[CookieJar] = None,
  96. ) -> str:
  97. """Download raw files of a dataset.
  98. Downloads all files at the specified revision. This
  99. is useful when you want all files from a dataset, because you don't know which
  100. ones you will need a priori. All files are nested inside a folder in order
  101. to keep their actual filename relative to that folder.
  102. An alternative would be to just clone a dataset but this would require that the
  103. user always has git and git-lfs installed, and properly configured.
  104. Args:
  105. dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
  106. file_path (str): The relative path of the file to download.
  107. revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
  108. commit hash. NOTE: currently only branch and tag name is supported
  109. cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will
  110. be save as cache_dir/dataset_id/THE_DATASET_FILES.
  111. local_dir (str, optional): Specific local directory path to which the file will be downloaded.
  112. user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
  113. local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
  114. local cached file if it exists.
  115. cookies (CookieJar, optional): The cookie of the request, default None.
  116. Raises:
  117. ValueError: the value details.
  118. Returns:
  119. str: Local folder path (string) of repo snapshot
  120. Note:
  121. Raises the following errors:
  122. - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
  123. if `use_auth_token=True` and the token cannot be found.
  124. - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
  125. ETag cannot be determined.
  126. - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  127. if some parameter value is invalid
  128. """
  129. return _repo_file_download(
  130. dataset_id,
  131. file_path,
  132. repo_type=REPO_TYPE_DATASET,
  133. revision=revision,
  134. cache_dir=cache_dir,
  135. user_agent=user_agent,
  136. local_files_only=local_files_only,
  137. cookies=cookies,
  138. local_dir=local_dir)
  139. def _repo_file_download(
  140. repo_id: str,
  141. file_path: str,
  142. *,
  143. repo_type: str = None,
  144. revision: Optional[str] = DEFAULT_MODEL_REVISION,
  145. cache_dir: Optional[str] = None,
  146. user_agent: Union[Dict, str, None] = None,
  147. local_files_only: Optional[bool] = False,
  148. cookies: Optional[CookieJar] = None,
  149. local_dir: Optional[str] = None,
  150. disable_tqdm: bool = False,
  151. ) -> Optional[str]: # pragma: no cover
  152. if not repo_type:
  153. repo_type = REPO_TYPE_MODEL
  154. if repo_type not in REPO_TYPE_SUPPORT:
  155. raise InvalidParameter('Invalid repo type: %s, only support: %s' %
  156. (repo_type, REPO_TYPE_SUPPORT))
  157. temporary_cache_dir, cache = create_temporary_directory_and_cache(
  158. repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
  159. # if local_files_only is `True` and the file already exists in cached_path
  160. # return the cached path
  161. if local_files_only:
  162. cached_file_path = cache.get_file_by_path(file_path)
  163. if cached_file_path is not None:
  164. logger.warning(
  165. "File exists in local cache, but we're not sure it's up to date"
  166. )
  167. return cached_file_path
  168. else:
  169. raise ValueError(
  170. 'Cannot find the requested files in the cached path and outgoing'
  171. ' traffic has been disabled. To enable look-ups and downloads'
  172. " online, set 'local_files_only' to False.")
  173. _api = HubApi()
  174. headers = {
  175. 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ),
  176. 'snapshot-identifier': str(uuid.uuid4()),
  177. }
  178. if INTRA_CLOUD_ACCELERATION == 'true':
  179. region_id: str = (
  180. os.getenv('INTRA_CLOUD_ACCELERATION_REGION')
  181. or _api._get_internal_acceleration_domain())
  182. if region_id:
  183. logger.info(
  184. f'Intra-cloud acceleration enabled for downloading from {repo_id}'
  185. )
  186. headers['x-aliyun-region-id'] = region_id
  187. if cookies is None:
  188. cookies = ModelScopeConfig.get_cookies()
  189. repo_files = []
  190. endpoint = _api.get_endpoint_for_read(repo_id=repo_id, repo_type=repo_type)
  191. file_to_download_meta = None
  192. if repo_type == REPO_TYPE_MODEL:
  193. revision = _api.get_valid_revision(
  194. repo_id, revision=revision, cookies=cookies, endpoint=endpoint)
  195. # we need to confirm the version is up-to-date
  196. # we need to get the file list to check if the latest version is cached, if so return, otherwise download
  197. repo_files = _api.get_model_files(
  198. model_id=repo_id,
  199. revision=revision,
  200. recursive=True,
  201. use_cookies=False if cookies is None else cookies,
  202. endpoint=endpoint)
  203. for repo_file in repo_files:
  204. if repo_file['Type'] == 'tree':
  205. continue
  206. if repo_file['Path'] == file_path:
  207. if cache.exists(repo_file):
  208. file_name = repo_file['Name']
  209. logger.debug(
  210. f'File {file_name} already in cache with identical hash, skip downloading!'
  211. )
  212. return cache.get_file_by_info(repo_file)
  213. else:
  214. file_to_download_meta = repo_file
  215. break
  216. elif repo_type == REPO_TYPE_DATASET:
  217. group_or_owner, name = model_id_to_group_owner_name(repo_id)
  218. if not revision:
  219. revision = DEFAULT_DATASET_REVISION
  220. page_number = 1
  221. page_size = 100
  222. while True:
  223. try:
  224. dataset_files = _api.get_dataset_files(
  225. repo_id=repo_id,
  226. revision=revision,
  227. root_path='/',
  228. recursive=True,
  229. page_number=page_number,
  230. page_size=page_size,
  231. endpoint=endpoint)
  232. except Exception as e:
  233. logger.error(
  234. f'Get dataset: {repo_id} file list failed, error: {e}')
  235. break
  236. is_exist = False
  237. for repo_file in dataset_files:
  238. if repo_file['Type'] == 'tree':
  239. continue
  240. if repo_file['Path'] == file_path:
  241. if cache.exists(repo_file):
  242. file_name = repo_file['Name']
  243. logger.debug(
  244. f'File {file_name} already in cache with identical hash, skip downloading!'
  245. )
  246. return cache.get_file_by_info(repo_file)
  247. else:
  248. file_to_download_meta = repo_file
  249. is_exist = True
  250. break
  251. if len(dataset_files) < page_size or is_exist:
  252. break
  253. page_number += 1
  254. if file_to_download_meta is None:
  255. raise NotExistError('The file path: %s not exist in: %s' %
  256. (file_path, repo_id))
  257. # we need to download again
  258. if repo_type == REPO_TYPE_MODEL:
  259. url_to_download = get_file_download_url(repo_id, file_path, revision,
  260. endpoint)
  261. elif repo_type == REPO_TYPE_DATASET:
  262. url_to_download = _api.get_dataset_file_url(
  263. file_name=file_to_download_meta['Path'],
  264. dataset_name=name,
  265. namespace=group_or_owner,
  266. revision=revision,
  267. endpoint=endpoint)
  268. else:
  269. raise ValueError(f'Invalid repo type {repo_type}')
  270. return download_file(url_to_download, file_to_download_meta,
  271. temporary_cache_dir, cache, headers, cookies)
  272. def move_legacy_cache_to_standard_dir(cache_dir: str, model_id: str):
  273. if cache_dir.endswith(os.path.sep):
  274. cache_dir = cache_dir.strip(os.path.sep)
  275. legacy_cache_root = os.path.dirname(cache_dir)
  276. base_name = os.path.basename(cache_dir)
  277. if base_name == 'datasets':
  278. # datasets will not be not affected
  279. return
  280. if not legacy_cache_root.endswith('hub'):
  281. # Two scenarios:
  282. # We have restructured ModelScope cache directory,
  283. # Scenery 1:
  284. # When MODELSCOPE_CACHE is not set, the default directory remains
  285. # the same at ~/.cache/modelscope/hub
  286. # Scenery 2:
  287. # When MODELSCOPE_CACHE is set, the cache directory is moved from
  288. # $MODELSCOPE_CACHE/hub to $MODELSCOPE_CACHE/. In this case,
  289. # we will be migrating the hub directory accordingly.
  290. legacy_cache_root = os.path.join(legacy_cache_root, 'hub')
  291. group_or_owner, name = model_id_to_group_owner_name(model_id)
  292. name = name.replace('.', '___')
  293. temporary_cache_dir = os.path.join(cache_dir, group_or_owner, name)
  294. legacy_cache_dir = os.path.join(legacy_cache_root, group_or_owner, name)
  295. if os.path.exists(
  296. legacy_cache_dir) and not os.path.exists(temporary_cache_dir):
  297. logger.info(
  298. f'Legacy cache dir exists: {legacy_cache_dir}, move to {temporary_cache_dir}'
  299. )
  300. try:
  301. shutil.move(legacy_cache_dir, temporary_cache_dir)
  302. except Exception: # noqa
  303. # Failed, skip
  304. pass
  305. def create_temporary_directory_and_cache(model_id: str,
  306. local_dir: str = None,
  307. cache_dir: str = None,
  308. repo_type: str = REPO_TYPE_MODEL):
  309. if repo_type == REPO_TYPE_MODEL:
  310. default_cache_root = get_model_cache_root()
  311. elif repo_type == REPO_TYPE_DATASET:
  312. default_cache_root = get_dataset_cache_root()
  313. else:
  314. raise ValueError(
  315. f'repo_type only support model and dataset, but now is : {repo_type}'
  316. )
  317. group_or_owner, name = model_id_to_group_owner_name(model_id)
  318. if local_dir is not None:
  319. temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)
  320. cache = ModelFileSystemCache(local_dir)
  321. else:
  322. if cache_dir is None:
  323. cache_dir = default_cache_root
  324. move_legacy_cache_to_standard_dir(cache_dir, model_id)
  325. if isinstance(cache_dir, Path):
  326. cache_dir = str(cache_dir)
  327. temporary_cache_dir = os.path.join(cache_dir, TEMPORARY_FOLDER_NAME,
  328. group_or_owner, name)
  329. name = name.replace('.', '___')
  330. cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
  331. os.makedirs(temporary_cache_dir, exist_ok=True)
  332. return temporary_cache_dir, cache
  333. def get_file_download_url(model_id: str,
  334. file_path: str,
  335. revision: str,
  336. endpoint: Optional[str] = None):
  337. """Format file download url according to `model_id`, `revision` and `file_path`.
  338. e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
  339. the resulted download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
  340. Args:
  341. model_id (str): The model_id.
  342. file_path (str): File path
  343. revision (str): File revision.
  344. endpoint (str): The remote endpoint
  345. Returns:
  346. str: The file url.
  347. """
  348. file_path = urllib.parse.quote_plus(file_path)
  349. revision = urllib.parse.quote_plus(revision)
  350. download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
  351. if not endpoint:
  352. endpoint = get_endpoint()
  353. return download_url_template.format(
  354. endpoint=endpoint,
  355. model_id=model_id,
  356. revision=revision,
  357. file_path=file_path,
  358. )
  359. def download_part_with_retry(params):
  360. # unpack parameters
  361. model_file_path, progress_callbacks, start, end, url, file_name, cookies, headers = params
  362. get_headers = {} if headers is None else copy.deepcopy(headers)
  363. get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
  364. retry = Retry(
  365. total=API_FILE_DOWNLOAD_RETRY_TIMES,
  366. backoff_factor=1,
  367. allowed_methods=['GET'])
  368. part_file_name = model_file_path + '_%s_%s' % (start, end)
  369. while True:
  370. try:
  371. partial_length = 0
  372. if os.path.exists(
  373. part_file_name): # download partial, continue download
  374. with open(part_file_name, 'rb') as f:
  375. partial_length = f.seek(0, io.SEEK_END)
  376. for callback in progress_callbacks:
  377. callback.update(partial_length)
  378. download_start = start + partial_length
  379. if download_start > end:
  380. break # this part is download completed.
  381. get_headers['Range'] = 'bytes=%s-%s' % (download_start, end)
  382. with open(part_file_name, 'ab+') as f:
  383. r = requests.get(
  384. url,
  385. stream=True,
  386. headers=get_headers,
  387. cookies=cookies,
  388. timeout=API_FILE_DOWNLOAD_TIMEOUT)
  389. for chunk in r.iter_content(
  390. chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
  391. if chunk: # filter out keep-alive new chunks
  392. f.write(chunk)
  393. for callback in progress_callbacks:
  394. callback.update(len(chunk))
  395. break
  396. except (Exception) as e: # no matter what exception, we will retry.
  397. retry = retry.increment('GET', url, error=e)
  398. logger.warning('Downloading: %s failed, reason: %s will retry' %
  399. (model_file_path, e))
  400. retry.sleep()
  401. def parallel_download(url: str,
  402. local_dir: str,
  403. file_name: str,
  404. cookies: CookieJar,
  405. headers: Optional[Dict[str, str]] = None,
  406. file_size: int = None,
  407. disable_tqdm: bool = False,
  408. progress_callbacks: List[Type[ProgressCallback]] = None,
  409. endpoint: str = None):
  410. progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
  411. )
  412. if not disable_tqdm:
  413. progress_callbacks.append(TqdmCallback)
  414. progress_callbacks = [
  415. callback(file_name, file_size) for callback in progress_callbacks
  416. ]
  417. # create temp file
  418. PART_SIZE = 160 * 1024 * 1024 # every part is 160M
  419. tasks = []
  420. file_path = os.path.join(local_dir, file_name)
  421. os.makedirs(os.path.dirname(file_path), exist_ok=True)
  422. for idx in range(int(file_size / PART_SIZE)):
  423. start = idx * PART_SIZE
  424. end = (idx + 1) * PART_SIZE - 1
  425. tasks.append((file_path, progress_callbacks, start, end, url,
  426. file_name, cookies, headers))
  427. if end + 1 < file_size:
  428. tasks.append((file_path, progress_callbacks, end + 1, file_size - 1,
  429. url, file_name, cookies, headers))
  430. parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
  431. # download every part
  432. with ThreadPoolExecutor(
  433. max_workers=parallels, thread_name_prefix='download') as executor:
  434. list(executor.map(download_part_with_retry, tasks))
  435. for callback in progress_callbacks:
  436. callback.end()
  437. # merge parts.
  438. hash_sha256 = hashlib.sha256()
  439. with open(os.path.join(local_dir, file_name), 'wb') as output_file:
  440. for task in tasks:
  441. part_file_name = task[0] + '_%s_%s' % (task[2], task[3])
  442. with open(part_file_name, 'rb') as part_file:
  443. while True:
  444. chunk = part_file.read(16 * API_FILE_DOWNLOAD_CHUNK_SIZE)
  445. if not chunk:
  446. break
  447. output_file.write(chunk)
  448. hash_sha256.update(chunk)
  449. os.remove(part_file_name)
  450. return hash_sha256.hexdigest()
  451. def http_get_model_file(
  452. url: str,
  453. local_dir: str,
  454. file_name: str,
  455. file_size: int,
  456. cookies: CookieJar,
  457. headers: Optional[Dict[str, str]] = None,
  458. disable_tqdm: bool = False,
  459. progress_callbacks: List[Type[ProgressCallback]] = None,
  460. ):
  461. """Download remote file, will retry 5 times before giving up on errors.
  462. Args:
  463. url(str):
  464. actual download url of the file
  465. local_dir(str):
  466. local directory where the downloaded file stores
  467. file_name(str):
  468. name of the file stored in `local_dir`
  469. file_size(int):
  470. The file size.
  471. cookies(CookieJar):
  472. cookies used to authentication the user, which is used for downloading private repos
  473. headers(Dict[str, str], optional):
  474. http headers to carry necessary info when requesting the remote file
  475. disable_tqdm(bool, optional): Disable the progress bar with tqdm.
  476. progress_callbacks(List[Type[ProgressCallback]], optional):
  477. progress callbacks to track the download progress.
  478. Raises:
  479. FileDownloadError: File download failed.
  480. """
  481. progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
  482. )
  483. if not disable_tqdm:
  484. progress_callbacks.append(TqdmCallback)
  485. progress_callbacks = [
  486. callback(file_name, file_size) for callback in progress_callbacks
  487. ]
  488. get_headers = {} if headers is None else copy.deepcopy(headers)
  489. get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
  490. temp_file_path = os.path.join(local_dir, file_name)
  491. os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)
  492. logger.debug('downloading %s to %s', url, temp_file_path)
  493. # retry sleep 0.5s, 1s, 2s, 4s
  494. has_retry = False
  495. hash_sha256 = hashlib.sha256()
  496. retry = Retry(
  497. total=API_FILE_DOWNLOAD_RETRY_TIMES,
  498. backoff_factor=1,
  499. allowed_methods=['GET'])
  500. while True:
  501. try:
  502. if file_size == 0:
  503. # Avoid empty file server request
  504. with open(temp_file_path, 'w+'):
  505. for callback in progress_callbacks:
  506. callback.update(1)
  507. break
  508. # Determine the length of any existing partial download
  509. partial_length = 0
  510. # download partial, continue download
  511. if os.path.exists(temp_file_path):
  512. # resuming from interrupted download is also considered as retry
  513. has_retry = True
  514. with open(temp_file_path, 'rb') as f:
  515. partial_length = f.seek(0, io.SEEK_END)
  516. for callback in progress_callbacks:
  517. callback.update(partial_length)
  518. # Check if download is complete
  519. if partial_length >= file_size:
  520. break
  521. # closed range[], from 0.
  522. get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
  523. file_size - 1)
  524. with open(temp_file_path, 'ab+') as f:
  525. r = requests.get(
  526. url,
  527. stream=True,
  528. headers=get_headers,
  529. cookies=cookies,
  530. timeout=API_FILE_DOWNLOAD_TIMEOUT)
  531. r.raise_for_status()
  532. for chunk in r.iter_content(
  533. chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
  534. if chunk: # filter out keep-alive new chunks
  535. for callback in progress_callbacks:
  536. callback.update(len(chunk))
  537. f.write(chunk)
  538. # hash would be discarded in retry case anyway
  539. if not has_retry:
  540. hash_sha256.update(chunk)
  541. break
  542. except Exception as e: # no matter what happen, we will retry.
  543. has_retry = True
  544. retry = retry.increment('GET', url, error=e)
  545. retry.sleep()
  546. for callback in progress_callbacks:
  547. callback.end()
  548. # if anything went wrong, we would discard the real-time computed hash and return None
  549. return None if has_retry else hash_sha256.hexdigest()
  550. def http_get_file(
  551. url: str,
  552. local_dir: str,
  553. file_name: str,
  554. cookies: CookieJar,
  555. headers: Optional[Dict[str, str]] = None,
  556. ):
  557. """Download remote file, will retry 5 times before giving up on errors.
  558. Args:
  559. url(str):
  560. actual download url of the file
  561. local_dir(str):
  562. local directory where the downloaded file stores
  563. file_name(str):
  564. name of the file stored in `local_dir`
  565. cookies(CookieJar):
  566. cookies used to authentication the user, which is used for downloading private repos
  567. headers(Dict[str, str], optional):
  568. http headers to carry necessary info when requesting the remote file
  569. Raises:
  570. FileDownloadError: File download failed.
  571. """
  572. total = -1
  573. temp_file_manager = partial(
  574. tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
  575. get_headers = {} if headers is None else copy.deepcopy(headers)
  576. get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
  577. with temp_file_manager() as temp_file:
  578. logger.debug('downloading %s to %s', url, temp_file.name)
  579. # retry sleep 0.5s, 1s, 2s, 4s
  580. retry = Retry(
  581. total=API_FILE_DOWNLOAD_RETRY_TIMES,
  582. backoff_factor=1,
  583. allowed_methods=['GET'])
  584. while True:
  585. try:
  586. downloaded_size = temp_file.tell()
  587. get_headers['Range'] = 'bytes=%d-' % downloaded_size
  588. r = requests.get(
  589. url,
  590. stream=True,
  591. headers=get_headers,
  592. cookies=cookies,
  593. timeout=API_FILE_DOWNLOAD_TIMEOUT)
  594. r.raise_for_status()
  595. content_length = r.headers.get('Content-Length')
  596. total = int(
  597. content_length) if content_length is not None else None
  598. progress = tqdm(
  599. unit='B',
  600. unit_scale=True,
  601. unit_divisor=1024,
  602. total=total,
  603. initial=downloaded_size,
  604. desc='Downloading [' + file_name + ']',
  605. )
  606. for chunk in r.iter_content(
  607. chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
  608. if chunk: # filter out keep-alive new chunks
  609. progress.update(len(chunk))
  610. temp_file.write(chunk)
  611. progress.close()
  612. break
  613. except (Exception) as e: # no matter what happen, we will retry.
  614. retry = retry.increment('GET', url, error=e)
  615. retry.sleep()
  616. logger.debug('storing %s in cache at %s', url, local_dir)
  617. downloaded_length = os.path.getsize(temp_file.name)
  618. if total != downloaded_length:
  619. os.remove(temp_file.name)
  620. msg = 'File %s download incomplete, content_length: %s but the \
  621. file downloaded length: %s, please download again' % (
  622. file_name, total, downloaded_length)
  623. logger.error(msg)
  624. raise FileDownloadError(msg)
  625. os.replace(temp_file.name, os.path.join(local_dir, file_name))
  626. def download_file(
  627. url,
  628. file_meta,
  629. temporary_cache_dir,
  630. cache,
  631. headers,
  632. cookies,
  633. disable_tqdm=False,
  634. progress_callbacks: List[Type[ProgressCallback]] = None,
  635. ):
  636. if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
  637. 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
  638. file_digest = parallel_download(
  639. url,
  640. temporary_cache_dir,
  641. file_meta['Path'],
  642. headers=headers,
  643. cookies=None if cookies is None else cookies.get_dict(),
  644. file_size=file_meta['Size'],
  645. disable_tqdm=disable_tqdm,
  646. progress_callbacks=progress_callbacks,
  647. )
  648. else:
  649. file_digest = http_get_model_file(
  650. url,
  651. temporary_cache_dir,
  652. file_meta['Path'],
  653. file_size=file_meta['Size'],
  654. headers=headers,
  655. cookies=cookies,
  656. disable_tqdm=disable_tqdm,
  657. progress_callbacks=progress_callbacks,
  658. )
  659. # check file integrity
  660. temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
  661. if FILE_HASH in file_meta:
  662. expected_hash = file_meta[FILE_HASH]
  663. # if a real-time hash has been computed
  664. if file_digest is not None:
  665. # if real-time hash mismatched, try to compute it again
  666. if file_digest != expected_hash:
  667. print(
  668. 'Mismatched real-time digest found, falling back to lump-sum hash computation'
  669. )
  670. file_integrity_validation(temp_file, expected_hash)
  671. else:
  672. file_integrity_validation(temp_file, expected_hash)
  673. # put file into to cache
  674. return cache.put_file(file_meta, temp_file)