file_download.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import copy
  3. import hashlib
  4. import io
  5. import os
  6. import shutil
  7. import tempfile
  8. import urllib
  9. import uuid
  10. from concurrent.futures import ThreadPoolExecutor
  11. from functools import partial
  12. from http.cookiejar import CookieJar
  13. from pathlib import Path
  14. from typing import Dict, List, Optional, Type, Union
  15. import requests
  16. from requests.adapters import Retry
  17. from tqdm.auto import tqdm
  18. from modelscope.hub.api import HubApi, ModelScopeConfig
  19. from modelscope.hub.constants import (
  20. API_FILE_DOWNLOAD_CHUNK_SIZE, API_FILE_DOWNLOAD_RETRY_TIMES,
  21. API_FILE_DOWNLOAD_TIMEOUT, FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS,
  22. MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB, TEMPORARY_FOLDER_NAME)
  23. from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
  24. DEFAULT_MODEL_REVISION,
  25. INTRA_CLOUD_ACCELERATION,
  26. REPO_TYPE_DATASET, REPO_TYPE_MODEL,
  27. REPO_TYPE_SUPPORT)
  28. from modelscope.utils.file_utils import (get_dataset_cache_root,
  29. get_model_cache_root)
  30. from modelscope.utils.logger import get_logger
  31. from .callback import ProgressCallback, TqdmCallback
  32. from .errors import FileDownloadError, InvalidParameter, NotExistError
  33. from .utils.caching import ModelFileSystemCache
  34. from .utils.utils import (file_integrity_validation, get_endpoint,
  35. model_id_to_group_owner_name)
# Module-level logger shared by all download helpers in this file.
logger = get_logger()
  37. def model_file_download(
  38. model_id: str,
  39. file_path: str,
  40. revision: Optional[str] = DEFAULT_MODEL_REVISION,
  41. cache_dir: Optional[str] = None,
  42. user_agent: Union[Dict, str, None] = None,
  43. local_files_only: Optional[bool] = False,
  44. cookies: Optional[CookieJar] = None,
  45. local_dir: Optional[str] = None,
  46. token: Optional[str] = None,
  47. ) -> Optional[str]: # pragma: no cover
  48. """Download from a given URL and cache it if it's not already present in the local cache.
  49. Given a URL, this function looks for the corresponding file in the local
  50. cache. If it's not there, download it. Then return the path to the cached
  51. file.
  52. Args:
  53. model_id (str): The model to whom the file to be downloaded belongs.
  54. file_path(str): Path of the file to be downloaded, relative to the root of model repo.
  55. revision(str, optional): revision of the model file to be downloaded.
  56. Can be any of a branch, tag or commit hash.
  57. cache_dir (str, Path, optional): Path to the folder where cached files are stored.
  58. user_agent (dict, str, optional): The user-agent info in the form of a dictionary or a string.
  59. local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
  60. local cached file if it exists. if `False`, download the file anyway even it exists.
  61. cookies (CookieJar, optional): The cookie of download request.
  62. local_dir (str, optional): Specific local directory path to which the file will be downloaded.
  63. token (str, optional): The user token.
  64. Returns:
  65. string: string of local file or if networking is off, last version of
  66. file cached on disk.
  67. Raises:
  68. NotExistError: The file is not exist.
  69. ValueError: The request parameter error.
  70. Note:
  71. Raises the following errors:
  72. - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
  73. if `use_auth_token=True` and the token cannot be found.
  74. - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError)
  75. if ETag cannot be determined.
  76. - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  77. if some parameter value is invalid
  78. """
  79. return _repo_file_download(
  80. model_id,
  81. file_path,
  82. repo_type=REPO_TYPE_MODEL,
  83. revision=revision,
  84. cache_dir=cache_dir,
  85. user_agent=user_agent,
  86. local_files_only=local_files_only,
  87. cookies=cookies,
  88. local_dir=local_dir,
  89. token=token)
  90. def dataset_file_download(
  91. dataset_id: str,
  92. file_path: str,
  93. revision: Optional[str] = DEFAULT_DATASET_REVISION,
  94. cache_dir: Union[str, Path, None] = None,
  95. local_dir: Optional[str] = None,
  96. user_agent: Optional[Union[Dict, str]] = None,
  97. local_files_only: Optional[bool] = False,
  98. cookies: Optional[CookieJar] = None,
  99. token: Optional[str] = None,
  100. ) -> str:
  101. """Download raw files of a dataset.
  102. Downloads all files at the specified revision. This
  103. is useful when you want all files from a dataset, because you don't know which
  104. ones you will need a priori. All files are nested inside a folder in order
  105. to keep their actual filename relative to that folder.
  106. An alternative would be to just clone a dataset but this would require that the
  107. user always has git and git-lfs installed, and properly configured.
  108. Args:
  109. dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
  110. file_path (str): The relative path of the file to download.
  111. revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
  112. commit hash. NOTE: currently only branch and tag name is supported
  113. cache_dir (str, Path, optional): Path to the folder where cached files are stored, dataset file will
  114. be save as cache_dir/dataset_id/THE_DATASET_FILES.
  115. local_dir (str, optional): Specific local directory path to which the file will be downloaded.
  116. user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
  117. local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
  118. local cached file if it exists.
  119. cookies (CookieJar, optional): The cookie of the request, default None.
  120. token (str, optional): The user token.
  121. Raises:
  122. ValueError: the value details.
  123. Returns:
  124. str: Local folder path (string) of repo snapshot
  125. Note:
  126. Raises the following errors:
  127. - [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
  128. if `use_auth_token=True` and the token cannot be found.
  129. - [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) if
  130. ETag cannot be determined.
  131. - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  132. if some parameter value is invalid
  133. """
  134. return _repo_file_download(
  135. dataset_id,
  136. file_path,
  137. repo_type=REPO_TYPE_DATASET,
  138. revision=revision,
  139. cache_dir=cache_dir,
  140. user_agent=user_agent,
  141. local_files_only=local_files_only,
  142. cookies=cookies,
  143. local_dir=local_dir,
  144. token=token)
  145. def _repo_file_download(
  146. repo_id: str,
  147. file_path: str,
  148. *,
  149. repo_type: str = None,
  150. revision: Optional[str] = DEFAULT_MODEL_REVISION,
  151. cache_dir: Optional[str] = None,
  152. user_agent: Union[Dict, str, None] = None,
  153. local_files_only: Optional[bool] = False,
  154. cookies: Optional[CookieJar] = None,
  155. local_dir: Optional[str] = None,
  156. disable_tqdm: bool = False,
  157. token: Optional[str] = None,
  158. ) -> Optional[str]: # pragma: no cover
  159. if not repo_type:
  160. repo_type = REPO_TYPE_MODEL
  161. if repo_type not in REPO_TYPE_SUPPORT:
  162. raise InvalidParameter('Invalid repo type: %s, only support: %s' %
  163. (repo_type, REPO_TYPE_SUPPORT))
  164. temporary_cache_dir, cache = create_temporary_directory_and_cache(
  165. repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
  166. # if local_files_only is `True` and the file already exists in cached_path
  167. # return the cached path
  168. if local_files_only:
  169. cached_file_path = cache.get_file_by_path(file_path)
  170. if cached_file_path is not None:
  171. logger.warning(
  172. "File exists in local cache, but we're not sure it's up to date"
  173. )
  174. return cached_file_path
  175. else:
  176. raise ValueError(
  177. 'Cannot find the requested files in the cached path and outgoing'
  178. ' traffic has been disabled. To enable look-ups and downloads'
  179. " online, set 'local_files_only' to False.")
  180. _api = HubApi(token=token)
  181. headers = {
  182. 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ),
  183. 'snapshot-identifier': str(uuid.uuid4()),
  184. }
  185. if INTRA_CLOUD_ACCELERATION == 'true':
  186. region_id: str = (
  187. os.getenv('INTRA_CLOUD_ACCELERATION_REGION')
  188. or _api._get_internal_acceleration_domain())
  189. if region_id:
  190. logger.info(
  191. f'Intra-cloud acceleration enabled for downloading from {repo_id}'
  192. )
  193. headers['x-aliyun-region-id'] = region_id
  194. if cookies is None:
  195. cookies = _api.get_cookies()
  196. repo_files = []
  197. endpoint = _api.get_endpoint_for_read(
  198. repo_id=repo_id, repo_type=repo_type, token=token)
  199. file_to_download_meta = None
  200. if repo_type == REPO_TYPE_MODEL:
  201. revision = _api.get_valid_revision(
  202. repo_id, revision=revision, cookies=cookies, endpoint=endpoint)
  203. # we need to confirm the version is up-to-date
  204. # we need to get the file list to check if the latest version is cached, if so return, otherwise download
  205. repo_files = _api.get_model_files(
  206. model_id=repo_id,
  207. revision=revision,
  208. recursive=True,
  209. use_cookies=False if cookies is None else cookies,
  210. endpoint=endpoint)
  211. for repo_file in repo_files:
  212. if repo_file['Type'] == 'tree':
  213. continue
  214. if repo_file['Path'] == file_path:
  215. if cache.exists(repo_file):
  216. file_name = repo_file['Name']
  217. logger.debug(
  218. f'File {file_name} already in cache with identical hash, skip downloading!'
  219. )
  220. return cache.get_file_by_info(repo_file)
  221. else:
  222. file_to_download_meta = repo_file
  223. break
  224. elif repo_type == REPO_TYPE_DATASET:
  225. group_or_owner, name = model_id_to_group_owner_name(repo_id)
  226. if not revision:
  227. revision = DEFAULT_DATASET_REVISION
  228. page_number = 1
  229. page_size = 100
  230. while True:
  231. try:
  232. dataset_files = _api.get_dataset_files(
  233. repo_id=repo_id,
  234. revision=revision,
  235. root_path='/',
  236. recursive=True,
  237. page_number=page_number,
  238. page_size=page_size,
  239. endpoint=endpoint,
  240. token=token)
  241. except Exception as e:
  242. logger.error(
  243. f'Get dataset: {repo_id} file list failed, error: {e}')
  244. break
  245. is_exist = False
  246. for repo_file in dataset_files:
  247. if repo_file['Type'] == 'tree':
  248. continue
  249. if repo_file['Path'] == file_path:
  250. if cache.exists(repo_file):
  251. file_name = repo_file['Name']
  252. logger.debug(
  253. f'File {file_name} already in cache with identical hash, skip downloading!'
  254. )
  255. return cache.get_file_by_info(repo_file)
  256. else:
  257. file_to_download_meta = repo_file
  258. is_exist = True
  259. break
  260. if len(dataset_files) < page_size or is_exist:
  261. break
  262. page_number += 1
  263. if file_to_download_meta is None:
  264. raise NotExistError('The file path: %s not exist in: %s' %
  265. (file_path, repo_id))
  266. # we need to download again
  267. if repo_type == REPO_TYPE_MODEL:
  268. url_to_download = get_file_download_url(repo_id, file_path, revision,
  269. endpoint)
  270. elif repo_type == REPO_TYPE_DATASET:
  271. url_to_download = _api.get_dataset_file_url(
  272. file_name=file_to_download_meta['Path'],
  273. dataset_name=name,
  274. namespace=group_or_owner,
  275. revision=revision,
  276. endpoint=endpoint)
  277. else:
  278. raise ValueError(f'Invalid repo type {repo_type}')
  279. return download_file(url_to_download, file_to_download_meta,
  280. temporary_cache_dir, cache, headers, cookies)
  281. def move_legacy_cache_to_standard_dir(cache_dir: str, model_id: str):
  282. if cache_dir.endswith(os.path.sep):
  283. cache_dir = cache_dir.strip(os.path.sep)
  284. legacy_cache_root = os.path.dirname(cache_dir)
  285. base_name = os.path.basename(cache_dir)
  286. if base_name == 'datasets':
  287. # datasets will not be not affected
  288. return
  289. if not legacy_cache_root.endswith('hub'):
  290. # Two scenarios:
  291. # We have restructured ModelScope cache directory,
  292. # Scenery 1:
  293. # When MODELSCOPE_CACHE is not set, the default directory remains
  294. # the same at ~/.cache/modelscope/hub
  295. # Scenery 2:
  296. # When MODELSCOPE_CACHE is set, the cache directory is moved from
  297. # $MODELSCOPE_CACHE/hub to $MODELSCOPE_CACHE/. In this case,
  298. # we will be migrating the hub directory accordingly.
  299. legacy_cache_root = os.path.join(legacy_cache_root, 'hub')
  300. group_or_owner, name = model_id_to_group_owner_name(model_id)
  301. name = name.replace('.', '___')
  302. temporary_cache_dir = os.path.join(cache_dir, group_or_owner, name)
  303. legacy_cache_dir = os.path.join(legacy_cache_root, group_or_owner, name)
  304. if os.path.exists(
  305. legacy_cache_dir) and not os.path.exists(temporary_cache_dir):
  306. logger.info(
  307. f'Legacy cache dir exists: {legacy_cache_dir}, move to {temporary_cache_dir}'
  308. )
  309. try:
  310. shutil.move(legacy_cache_dir, temporary_cache_dir)
  311. except Exception: # noqa
  312. # Failed, skip
  313. pass
  314. def create_temporary_directory_and_cache(model_id: str,
  315. local_dir: str = None,
  316. cache_dir: str = None,
  317. repo_type: str = REPO_TYPE_MODEL):
  318. if repo_type == REPO_TYPE_MODEL:
  319. default_cache_root = get_model_cache_root()
  320. elif repo_type == REPO_TYPE_DATASET:
  321. default_cache_root = get_dataset_cache_root()
  322. else:
  323. raise ValueError(
  324. f'repo_type only support model and dataset, but now is : {repo_type}'
  325. )
  326. group_or_owner, name = model_id_to_group_owner_name(model_id)
  327. if local_dir is not None:
  328. temporary_cache_dir = os.path.join(local_dir, TEMPORARY_FOLDER_NAME)
  329. cache = ModelFileSystemCache(local_dir)
  330. else:
  331. if cache_dir is None:
  332. cache_dir = default_cache_root
  333. move_legacy_cache_to_standard_dir(cache_dir, model_id)
  334. if isinstance(cache_dir, Path):
  335. cache_dir = str(cache_dir)
  336. temporary_cache_dir = os.path.join(cache_dir, TEMPORARY_FOLDER_NAME,
  337. group_or_owner, name)
  338. name = name.replace('.', '___')
  339. cache = ModelFileSystemCache(cache_dir, group_or_owner, name)
  340. os.makedirs(temporary_cache_dir, exist_ok=True)
  341. return temporary_cache_dir, cache
  342. def get_file_download_url(model_id: str,
  343. file_path: str,
  344. revision: str,
  345. endpoint: Optional[str] = None):
  346. """Format file download url according to `model_id`, `revision` and `file_path`.
  347. e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
  348. the resulted download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
  349. Args:
  350. model_id (str): The model_id.
  351. file_path (str): File path
  352. revision (str): File revision.
  353. endpoint (str): The remote endpoint
  354. Returns:
  355. str: The file url.
  356. """
  357. file_path = urllib.parse.quote_plus(file_path)
  358. revision = urllib.parse.quote_plus(revision)
  359. download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
  360. if not endpoint:
  361. endpoint = get_endpoint()
  362. return download_url_template.format(
  363. endpoint=endpoint,
  364. model_id=model_id,
  365. revision=revision,
  366. file_path=file_path,
  367. )
  368. def download_part_with_retry(params):
  369. # unpack parameters
  370. model_file_path, progress_callbacks, start, end, url, file_name, cookies, headers = params
  371. get_headers = {} if headers is None else copy.deepcopy(headers)
  372. get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
  373. retry = Retry(
  374. total=API_FILE_DOWNLOAD_RETRY_TIMES,
  375. backoff_factor=1,
  376. allowed_methods=['GET'])
  377. part_file_name = model_file_path + '_%s_%s' % (start, end)
  378. while True:
  379. try:
  380. partial_length = 0
  381. if os.path.exists(
  382. part_file_name): # download partial, continue download
  383. with open(part_file_name, 'rb') as f:
  384. partial_length = f.seek(0, io.SEEK_END)
  385. for callback in progress_callbacks:
  386. callback.update(partial_length)
  387. download_start = start + partial_length
  388. if download_start > end:
  389. break # this part is download completed.
  390. get_headers['Range'] = 'bytes=%s-%s' % (download_start, end)
  391. with open(part_file_name, 'ab+') as f:
  392. r = requests.get(
  393. url,
  394. stream=True,
  395. headers=get_headers,
  396. cookies=cookies,
  397. timeout=API_FILE_DOWNLOAD_TIMEOUT)
  398. for chunk in r.iter_content(
  399. chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
  400. if chunk: # filter out keep-alive new chunks
  401. f.write(chunk)
  402. for callback in progress_callbacks:
  403. callback.update(len(chunk))
  404. break
  405. except (Exception) as e: # no matter what exception, we will retry.
  406. retry = retry.increment('GET', url, error=e)
  407. logger.warning('Downloading: %s failed, reason: %s will retry' %
  408. (model_file_path, e))
  409. retry.sleep()
  410. def parallel_download(url: str,
  411. local_dir: str,
  412. file_name: str,
  413. cookies: CookieJar,
  414. headers: Optional[Dict[str, str]] = None,
  415. file_size: int = None,
  416. disable_tqdm: bool = False,
  417. progress_callbacks: List[Type[ProgressCallback]] = None,
  418. endpoint: str = None):
  419. progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
  420. )
  421. if not disable_tqdm:
  422. progress_callbacks.append(TqdmCallback)
  423. progress_callbacks = [
  424. callback(file_name, file_size) for callback in progress_callbacks
  425. ]
  426. # create temp file
  427. PART_SIZE = 160 * 1024 * 1024 # every part is 160M
  428. tasks = []
  429. file_path = os.path.join(local_dir, file_name)
  430. os.makedirs(os.path.dirname(file_path), exist_ok=True)
  431. for idx in range(int(file_size / PART_SIZE)):
  432. start = idx * PART_SIZE
  433. end = (idx + 1) * PART_SIZE - 1
  434. tasks.append((file_path, progress_callbacks, start, end, url,
  435. file_name, cookies, headers))
  436. if end + 1 < file_size:
  437. tasks.append((file_path, progress_callbacks, end + 1, file_size - 1,
  438. url, file_name, cookies, headers))
  439. parallels = min(MODELSCOPE_DOWNLOAD_PARALLELS, 16)
  440. # download every part
  441. with ThreadPoolExecutor(
  442. max_workers=parallels, thread_name_prefix='download') as executor:
  443. list(executor.map(download_part_with_retry, tasks))
  444. for callback in progress_callbacks:
  445. callback.end()
  446. # merge parts.
  447. hash_sha256 = hashlib.sha256()
  448. with open(os.path.join(local_dir, file_name), 'wb') as output_file:
  449. for task in tasks:
  450. part_file_name = task[0] + '_%s_%s' % (task[2], task[3])
  451. with open(part_file_name, 'rb') as part_file:
  452. while True:
  453. chunk = part_file.read(16 * API_FILE_DOWNLOAD_CHUNK_SIZE)
  454. if not chunk:
  455. break
  456. output_file.write(chunk)
  457. hash_sha256.update(chunk)
  458. os.remove(part_file_name)
  459. return hash_sha256.hexdigest()
  460. def http_get_model_file(
  461. url: str,
  462. local_dir: str,
  463. file_name: str,
  464. file_size: int,
  465. cookies: CookieJar,
  466. headers: Optional[Dict[str, str]] = None,
  467. disable_tqdm: bool = False,
  468. progress_callbacks: List[Type[ProgressCallback]] = None,
  469. ):
  470. """Download remote file, will retry 5 times before giving up on errors.
  471. Args:
  472. url(str):
  473. actual download url of the file
  474. local_dir(str):
  475. local directory where the downloaded file stores
  476. file_name(str):
  477. name of the file stored in `local_dir`
  478. file_size(int):
  479. The file size.
  480. cookies(CookieJar):
  481. cookies used to authentication the user, which is used for downloading private repos
  482. headers(Dict[str, str], optional):
  483. http headers to carry necessary info when requesting the remote file
  484. disable_tqdm(bool, optional): Disable the progress bar with tqdm.
  485. progress_callbacks(List[Type[ProgressCallback]], optional):
  486. progress callbacks to track the download progress.
  487. Raises:
  488. FileDownloadError: File download failed.
  489. """
  490. progress_callbacks = [] if progress_callbacks is None else progress_callbacks.copy(
  491. )
  492. if not disable_tqdm:
  493. progress_callbacks.append(TqdmCallback)
  494. progress_callbacks = [
  495. callback(file_name, file_size) for callback in progress_callbacks
  496. ]
  497. get_headers = {} if headers is None else copy.deepcopy(headers)
  498. get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
  499. temp_file_path = os.path.join(local_dir, file_name)
  500. os.makedirs(os.path.dirname(temp_file_path), exist_ok=True)
  501. logger.debug('downloading %s to %s', url, temp_file_path)
  502. # retry sleep 0.5s, 1s, 2s, 4s
  503. has_retry = False
  504. hash_sha256 = hashlib.sha256()
  505. retry = Retry(
  506. total=API_FILE_DOWNLOAD_RETRY_TIMES,
  507. backoff_factor=1,
  508. allowed_methods=['GET'])
  509. while True:
  510. try:
  511. if file_size == 0:
  512. # Avoid empty file server request
  513. with open(temp_file_path, 'w+'):
  514. for callback in progress_callbacks:
  515. callback.update(1)
  516. break
  517. # Determine the length of any existing partial download
  518. partial_length = 0
  519. # download partial, continue download
  520. if os.path.exists(temp_file_path):
  521. # resuming from interrupted download is also considered as retry
  522. has_retry = True
  523. with open(temp_file_path, 'rb') as f:
  524. partial_length = f.seek(0, io.SEEK_END)
  525. for callback in progress_callbacks:
  526. callback.update(partial_length)
  527. # Check if download is complete
  528. if partial_length >= file_size:
  529. break
  530. # closed range[], from 0.
  531. get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
  532. file_size - 1)
  533. with open(temp_file_path, 'ab+') as f:
  534. r = requests.get(
  535. url,
  536. stream=True,
  537. headers=get_headers,
  538. cookies=cookies,
  539. timeout=API_FILE_DOWNLOAD_TIMEOUT)
  540. r.raise_for_status()
  541. for chunk in r.iter_content(
  542. chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
  543. if chunk: # filter out keep-alive new chunks
  544. for callback in progress_callbacks:
  545. callback.update(len(chunk))
  546. f.write(chunk)
  547. # hash would be discarded in retry case anyway
  548. if not has_retry:
  549. hash_sha256.update(chunk)
  550. break
  551. except Exception as e: # no matter what happen, we will retry.
  552. has_retry = True
  553. retry = retry.increment('GET', url, error=e)
  554. retry.sleep()
  555. for callback in progress_callbacks:
  556. callback.end()
  557. # if anything went wrong, we would discard the real-time computed hash and return None
  558. return None if has_retry else hash_sha256.hexdigest()
  559. def http_get_file(
  560. url: str,
  561. local_dir: str,
  562. file_name: str,
  563. cookies: CookieJar,
  564. headers: Optional[Dict[str, str]] = None,
  565. ):
  566. """Download remote file, will retry 5 times before giving up on errors.
  567. Args:
  568. url(str):
  569. actual download url of the file
  570. local_dir(str):
  571. local directory where the downloaded file stores
  572. file_name(str):
  573. name of the file stored in `local_dir`
  574. cookies(CookieJar):
  575. cookies used to authentication the user, which is used for downloading private repos
  576. headers(Dict[str, str], optional):
  577. http headers to carry necessary info when requesting the remote file
  578. Raises:
  579. FileDownloadError: File download failed.
  580. """
  581. total = -1
  582. temp_file_manager = partial(
  583. tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
  584. get_headers = {} if headers is None else copy.deepcopy(headers)
  585. get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
  586. with temp_file_manager() as temp_file:
  587. logger.debug('downloading %s to %s', url, temp_file.name)
  588. # retry sleep 0.5s, 1s, 2s, 4s
  589. retry = Retry(
  590. total=API_FILE_DOWNLOAD_RETRY_TIMES,
  591. backoff_factor=1,
  592. allowed_methods=['GET'])
  593. while True:
  594. try:
  595. downloaded_size = temp_file.tell()
  596. get_headers['Range'] = 'bytes=%d-' % downloaded_size
  597. r = requests.get(
  598. url,
  599. stream=True,
  600. headers=get_headers,
  601. cookies=cookies,
  602. timeout=API_FILE_DOWNLOAD_TIMEOUT)
  603. r.raise_for_status()
  604. content_length = r.headers.get('Content-Length')
  605. total = int(
  606. content_length) if content_length is not None else None
  607. progress = tqdm(
  608. unit='B',
  609. unit_scale=True,
  610. unit_divisor=1024,
  611. total=total,
  612. initial=downloaded_size,
  613. desc='Downloading [' + file_name + ']',
  614. )
  615. for chunk in r.iter_content(
  616. chunk_size=API_FILE_DOWNLOAD_CHUNK_SIZE):
  617. if chunk: # filter out keep-alive new chunks
  618. progress.update(len(chunk))
  619. temp_file.write(chunk)
  620. progress.close()
  621. break
  622. except (Exception) as e: # no matter what happen, we will retry.
  623. retry = retry.increment('GET', url, error=e)
  624. retry.sleep()
  625. logger.debug('storing %s in cache at %s', url, local_dir)
  626. downloaded_length = os.path.getsize(temp_file.name)
  627. if total != downloaded_length:
  628. os.remove(temp_file.name)
  629. msg = 'File %s download incomplete, content_length: %s but the \
  630. file downloaded length: %s, please download again' % (
  631. file_name, total, downloaded_length)
  632. logger.error(msg)
  633. raise FileDownloadError(msg)
  634. os.replace(temp_file.name, os.path.join(local_dir, file_name))
  635. def download_file(
  636. url,
  637. file_meta,
  638. temporary_cache_dir,
  639. cache,
  640. headers,
  641. cookies,
  642. disable_tqdm=False,
  643. progress_callbacks: List[Type[ProgressCallback]] = None,
  644. ):
  645. if MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < file_meta[
  646. 'Size'] and MODELSCOPE_DOWNLOAD_PARALLELS > 1: # parallel download large file.
  647. file_digest = parallel_download(
  648. url,
  649. temporary_cache_dir,
  650. file_meta['Path'],
  651. headers=headers,
  652. cookies=None if cookies is None else cookies.get_dict(),
  653. file_size=file_meta['Size'],
  654. disable_tqdm=disable_tqdm,
  655. progress_callbacks=progress_callbacks,
  656. )
  657. else:
  658. file_digest = http_get_model_file(
  659. url,
  660. temporary_cache_dir,
  661. file_meta['Path'],
  662. file_size=file_meta['Size'],
  663. headers=headers,
  664. cookies=cookies,
  665. disable_tqdm=disable_tqdm,
  666. progress_callbacks=progress_callbacks,
  667. )
  668. # check file integrity
  669. temp_file = os.path.join(temporary_cache_dir, file_meta['Path'])
  670. if FILE_HASH in file_meta:
  671. expected_hash = file_meta[FILE_HASH]
  672. # if a real-time hash has been computed
  673. if file_digest is not None:
  674. # if real-time hash mismatched, try to compute it again
  675. if file_digest != expected_hash:
  676. print(
  677. 'Mismatched real-time digest found, falling back to lump-sum hash computation'
  678. )
  679. file_integrity_validation(temp_file, expected_hash)
  680. else:
  681. file_integrity_validation(temp_file, expected_hash)
  682. # put file into to cache
  683. return cache.put_file(file_meta, temp_file)