hf_file_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346
  1. # noqa: isort:skip_file, yapf: disable
  2. # Copyright (c) Alibaba, Inc. and its affiliates.
  3. # Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
  4. import json
  5. import os
  6. import re
  7. import copy
  8. import shutil
  9. import time
  10. import warnings
  11. from contextlib import contextmanager
  12. from functools import partial
  13. from pathlib import Path
  14. from typing import Optional, Union
  15. from urllib.parse import urljoin, urlparse
  16. import requests
  17. from tqdm.auto import tqdm
  18. from datasets import config
  19. from datasets.utils.file_utils import hash_url_to_filename, \
  20. get_authentication_headers_for_url, fsspec_head, fsspec_get
  21. from filelock import FileLock
  22. from modelscope.utils.config_ds import MS_DATASETS_CACHE
  23. from modelscope.utils.logger import get_logger
  24. from modelscope.hub.api import ModelScopeConfig
  25. from modelscope import __version__
# Module-level logger shared by the retry, ETag and cache helpers in this file.
logger = get_logger()
  27. def get_datasets_user_agent_ms(user_agent: Optional[Union[str, dict]] = None) -> str:
  28. ua = f'datasets/{__version__}'
  29. ua += f'; python/{config.PY_VERSION}'
  30. ua += f'; pyarrow/{config.PYARROW_VERSION}'
  31. if config.TORCH_AVAILABLE:
  32. ua += f'; torch/{config.TORCH_VERSION}'
  33. if config.TF_AVAILABLE:
  34. ua += f'; tensorflow/{config.TF_VERSION}'
  35. if config.JAX_AVAILABLE:
  36. ua += f'; jax/{config.JAX_VERSION}'
  37. if isinstance(user_agent, dict):
  38. ua += f"; {'; '.join(f'{k}/{v}' for k, v in user_agent.items())}"
  39. elif isinstance(user_agent, str):
  40. ua += '; ' + user_agent
  41. return ua
  42. def _request_with_retry_ms(
  43. method: str,
  44. url: str,
  45. max_retries: int = 2,
  46. base_wait_time: float = 0.5,
  47. max_wait_time: float = 2,
  48. timeout: float = 10.0,
  49. **params,
  50. ) -> requests.Response:
  51. """Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff.
  52. Note that if the environment variable HF_DATASETS_OFFLINE is set to 1, then a OfflineModeIsEnabled error is raised.
  53. Args:
  54. method (str): HTTP method, such as 'GET' or 'HEAD'.
  55. url (str): The URL of the resource to fetch.
  56. max_retries (int): Maximum number of retries, defaults to 0 (no retries).
  57. base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between
  58. retries then grows exponentially, capped by max_wait_time.
  59. max_wait_time (float): Maximum amount of time between two retries, in seconds.
  60. **params (additional keyword arguments): Params to pass to :obj:`requests.request`.
  61. """
  62. tries, success = 0, False
  63. response = None
  64. while not success:
  65. tries += 1
  66. try:
  67. response = requests.request(method=method.upper(), url=url, timeout=timeout, **params)
  68. success = True
  69. except (requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError) as err:
  70. if tries > max_retries:
  71. raise err
  72. else:
  73. logger.info(f'{method} request to {url} timed out, retrying... [{tries/max_retries}]')
  74. sleep_time = min(max_wait_time, base_wait_time * 2 ** (tries - 1)) # Exponential backoff
  75. time.sleep(sleep_time)
  76. return response
  77. def http_head_ms(
  78. url, proxies=None, headers=None, cookies=None, allow_redirects=True, timeout=10.0, max_retries=0
  79. ) -> requests.Response:
  80. headers = copy.deepcopy(headers) or {}
  81. headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
  82. response = _request_with_retry_ms(
  83. method='HEAD',
  84. url=url,
  85. proxies=proxies,
  86. headers=headers,
  87. cookies=cookies,
  88. allow_redirects=allow_redirects,
  89. timeout=timeout,
  90. max_retries=max_retries,
  91. )
  92. return response
  93. def http_get_ms(
  94. url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
  95. ) -> Optional[requests.Response]:
  96. headers = dict(headers) if headers is not None else {}
  97. headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
  98. if resume_size > 0:
  99. headers['Range'] = f'bytes={resume_size:d}-'
  100. response = _request_with_retry_ms(
  101. method='GET',
  102. url=url,
  103. stream=True,
  104. proxies=proxies,
  105. headers=headers,
  106. cookies=cookies,
  107. max_retries=max_retries,
  108. timeout=timeout,
  109. )
  110. if temp_file is None:
  111. return response
  112. if response.status_code == 416: # Range not satisfiable
  113. return
  114. content_length = response.headers.get('Content-Length')
  115. total = resume_size + int(content_length) if content_length is not None else None
  116. progress = tqdm(total=total, initial=resume_size, unit_scale=True, unit='B', desc=desc or 'Downloading')
  117. for chunk in response.iter_content(chunk_size=1024):
  118. progress.update(len(chunk))
  119. temp_file.write(chunk)
  120. progress.close()
  121. def get_from_cache_ms(
  122. url,
  123. cache_dir=None,
  124. force_download=False,
  125. proxies=None,
  126. etag_timeout=100,
  127. resume_download=False,
  128. user_agent=None,
  129. local_files_only=False,
  130. use_etag=True,
  131. max_retries=0,
  132. token=None,
  133. use_auth_token='deprecated',
  134. ignore_url_params=False,
  135. storage_options=None,
  136. download_desc=None,
  137. disable_tqdm=None,
  138. ) -> str:
  139. """
  140. Given a URL, look for the corresponding file in the local cache.
  141. If it's not there, download it. Then return the path to the cached file.
  142. Return:
  143. Local path (string)
  144. Raises:
  145. FileNotFoundError: in case of non-recoverable file
  146. (non-existent or no cache on disk)
  147. ConnectionError: in case of unreachable url
  148. and no cache on disk
  149. """
  150. if use_auth_token != 'deprecated':
  151. warnings.warn(
  152. "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
  153. f"You can remove this warning by passing 'token={use_auth_token}' instead.",
  154. FutureWarning,
  155. )
  156. token = use_auth_token
  157. if cache_dir is None:
  158. cache_dir = MS_DATASETS_CACHE
  159. if isinstance(cache_dir, Path):
  160. cache_dir = str(cache_dir)
  161. os.makedirs(cache_dir, exist_ok=True)
  162. if ignore_url_params:
  163. # strip all query parameters and #fragments from the URL
  164. cached_url = urljoin(url, urlparse(url).path)
  165. else:
  166. cached_url = url # additional parameters may be added to the given URL
  167. connected = False
  168. response = None
  169. cookies = None
  170. etag = None
  171. head_error = None
  172. scheme = None
  173. # Try a first time to file the file on the local file system without eTag (None)
  174. # if we don't ask for 'force_download' then we spare a request
  175. filename = hash_url_to_filename(cached_url, etag=None)
  176. cache_path = os.path.join(cache_dir, filename)
  177. if download_desc is None:
  178. download_desc = 'Downloading [' + filename + ']'
  179. if os.path.exists(cache_path) and not force_download and not use_etag:
  180. return cache_path
  181. # Prepare headers for authentication
  182. headers = get_authentication_headers_for_url(url, token=token)
  183. if user_agent is not None:
  184. headers['user-agent'] = user_agent
  185. # We don't have the file locally or we need an eTag
  186. if not local_files_only:
  187. scheme = urlparse(url).scheme
  188. if scheme not in ('http', 'https'):
  189. response = fsspec_head(url, storage_options=storage_options)
  190. # s3fs uses "ETag", gcsfs uses "etag"
  191. etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
  192. connected = True
  193. try:
  194. cookies = ModelScopeConfig.get_cookies()
  195. response = http_head_ms(
  196. url,
  197. allow_redirects=True,
  198. proxies=proxies,
  199. timeout=etag_timeout,
  200. max_retries=max_retries,
  201. headers=headers,
  202. cookies=cookies,
  203. )
  204. if response.status_code == 200: # ok
  205. etag = response.headers.get('ETag') if use_etag else None
  206. for k, v in response.cookies.items():
  207. # In some edge cases, we need to get a confirmation token
  208. if k.startswith('download_warning') and 'drive.google.com' in url:
  209. url += '&confirm=' + v
  210. cookies = response.cookies
  211. connected = True
  212. # Fix Google Drive URL to avoid Virus scan warning
  213. if 'drive.google.com' in url and 'confirm=' not in url:
  214. url += '&confirm=t'
  215. # In some edge cases, head request returns 400 but the connection is actually ok
  216. elif (
  217. (response.status_code == 400 and 'firebasestorage.googleapis.com' in url)
  218. or (response.status_code == 405 and 'drive.google.com' in url)
  219. or (
  220. response.status_code == 403
  221. and (
  222. re.match(r'^https?://github.com/.*?/.*?/releases/download/.*?/.*?$', url)
  223. or re.match(r'^https://.*?s3.*?amazonaws.com/.*?$', response.url)
  224. )
  225. )
  226. or (response.status_code == 403 and 'ndownloader.figstatic.com' in url)
  227. ):
  228. connected = True
  229. logger.info(f"Couldn't get ETag version for url {url}")
  230. elif response.status_code == 401 and config.HF_ENDPOINT in url and token is None:
  231. raise ConnectionError(
  232. f'Unauthorized for URL {url}. '
  233. f'Please use the parameter `token=True` after logging in with `huggingface-cli login`'
  234. )
  235. except (OSError, requests.exceptions.Timeout) as e:
  236. # not connected
  237. head_error = e
  238. pass
  239. # connected == False = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
  240. # try to get the last downloaded one
  241. if not connected:
  242. if os.path.exists(cache_path) and not force_download:
  243. return cache_path
  244. if local_files_only:
  245. raise FileNotFoundError(
  246. f'Cannot find the requested files in the cached path at {cache_path} and outgoing traffic has been'
  247. " disabled. To enable file online look-ups, set 'local_files_only' to False."
  248. )
  249. elif response is not None and response.status_code == 404:
  250. raise FileNotFoundError(f"Couldn't find file at {url}")
  251. if head_error is not None:
  252. raise ConnectionError(f"Couldn't reach {url} ({repr(head_error)})")
  253. elif response is not None:
  254. raise ConnectionError(f"Couldn't reach {url} (error {response.status_code})")
  255. else:
  256. raise ConnectionError(f"Couldn't reach {url}")
  257. # Try a second time
  258. filename = hash_url_to_filename(cached_url, etag)
  259. cache_path = os.path.join(cache_dir, filename)
  260. if os.path.exists(cache_path) and not force_download:
  261. return cache_path
  262. # From now on, connected is True.
  263. # Prevent parallel downloads of the same file with a lock.
  264. lock_path = cache_path + '.lock'
  265. with FileLock(lock_path):
  266. # Retry in case previously locked processes just enter after the precedent process releases the lock
  267. if os.path.exists(cache_path) and not force_download:
  268. return cache_path
  269. incomplete_path = cache_path + '.incomplete'
  270. @contextmanager
  271. def temp_file_manager(mode='w+b'):
  272. with open(incomplete_path, mode) as f:
  273. yield f
  274. resume_size = 0
  275. if resume_download:
  276. temp_file_manager = partial(temp_file_manager, mode='a+b')
  277. if os.path.exists(incomplete_path):
  278. resume_size = os.stat(incomplete_path).st_size
  279. # Download to temporary file, then copy to cache path once finished.
  280. # Otherwise, you get corrupt cache entries if the download gets interrupted.
  281. with temp_file_manager() as temp_file:
  282. # GET file object
  283. if scheme not in ('http', 'https'):
  284. fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
  285. else:
  286. http_get_ms(
  287. url,
  288. temp_file=temp_file,
  289. proxies=proxies,
  290. resume_size=resume_size,
  291. headers=headers,
  292. cookies=cookies,
  293. max_retries=max_retries,
  294. desc=download_desc,
  295. )
  296. logger.info(f'storing {url} in cache at {cache_path}')
  297. shutil.move(temp_file.name, cache_path)
  298. umask = os.umask(0o666)
  299. os.umask(umask)
  300. os.chmod(cache_path, 0o666 & ~umask)
  301. logger.info(f'creating metadata file for {cache_path}')
  302. meta = {'url': url, 'etag': etag}
  303. meta_path = cache_path + '.json'
  304. with open(meta_path, 'w', encoding='utf-8') as meta_file:
  305. json.dump(meta, meta_file)
  306. return cache_path