download.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import hashlib
  15. import os
  16. import os.path as osp
  17. import shutil
  18. import sys
  19. import tarfile
  20. import time
  21. import zipfile
  22. import httpx
  23. try:
  24. from tqdm import tqdm
  25. except:
  26. class tqdm:
  27. def __init__(self, total=None):
  28. self.total = total
  29. self.n = 0
  30. def update(self, n):
  31. self.n += n
  32. if self.total is None:
  33. sys.stderr.write(f"\r{self.n:.1f} bytes")
  34. else:
  35. sys.stderr.write(f"\r{100 * self.n / float(self.total):.1f}%")
  36. sys.stderr.flush()
  37. def __enter__(self):
  38. return self
  39. def __exit__(self, exc_type, exc_val, exc_tb):
  40. sys.stderr.write('\n')
  41. import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Names exported by `from <this module> import *`.
__all__ = ['get_weights_path_from_url']

# Directory (under the user's home) that get_weights_path_from_url passes as
# the download root.
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
# Number of download attempts `_download` makes before raising.
DOWNLOAD_RETRY_LIMIT = 3
  46. def is_url(path):
  47. """
  48. Whether path is URL.
  49. Args:
  50. path (string): URL string or not.
  51. """
  52. return path.startswith('http://') or path.startswith('https://')
  53. def get_weights_path_from_url(url, md5sum=None):
  54. """Get weights path from WEIGHT_HOME, if not exists,
  55. download it from url.
  56. Args:
  57. url (str): download url
  58. md5sum (str): md5 sum of download package
  59. Returns:
  60. str: a local path to save downloaded weights.
  61. Examples:
  62. .. code-block:: python
  63. >>> from paddle.utils.download import get_weights_path_from_url
  64. >>> resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
  65. >>> local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)
  66. """
  67. path = get_path_from_url(url, WEIGHTS_HOME, md5sum)
  68. return path
  69. def _map_path(url, root_dir):
  70. # parse path after download under root_dir
  71. fname = osp.split(url)[-1]
  72. fpath = fname
  73. return osp.join(root_dir, fpath)
  74. def _get_unique_endpoints(trainer_endpoints):
  75. # Sorting is to avoid different environmental variables for each card
  76. trainer_endpoints.sort()
  77. ips = set()
  78. unique_endpoints = set()
  79. for endpoint in trainer_endpoints:
  80. ip = endpoint.split(":")[0]
  81. if ip in ips:
  82. continue
  83. ips.add(ip)
  84. unique_endpoints.add(endpoint)
  85. logger.info(f"unique_endpoints {unique_endpoints}")
  86. return unique_endpoints
  87. def get_path_from_url(
  88. url, root_dir, md5sum=None, check_exist=True, decompress=True, method='get'
  89. ):
  90. """Download from given url to root_dir.
  91. if file or directory specified by url is exists under
  92. root_dir, return the path directly, otherwise download
  93. from url and decompress it, return the path.
  94. Args:
  95. url (str): download url
  96. root_dir (str): root dir for downloading, it should be
  97. WEIGHTS_HOME or DATASET_HOME
  98. md5sum (str): md5 sum of download package
  99. decompress (bool): decompress zip or tar file. Default is `True`
  100. method (str): which download method to use. Support `wget` and `get`. Default is `get`.
  101. Returns:
  102. str: a local path to save downloaded models & weights & datasets.
  103. """
  104. from paddle.distributed import ParallelEnv
  105. assert is_url(url), f"downloading from {url} not a url"
  106. # parse path after download to decompress under root_dir
  107. fullpath = _map_path(url, root_dir)
  108. # Mainly used to solve the problem of downloading data from different
  109. # machines in the case of multiple machines. Different ips will download
  110. # data, and the same ip will only download data once.
  111. unique_endpoints = _get_unique_endpoints(ParallelEnv().trainer_endpoints[:])
  112. if osp.exists(fullpath) and check_exist and _md5check(fullpath, md5sum):
  113. logger.info(f"Found {fullpath}")
  114. else:
  115. if ParallelEnv().current_endpoint in unique_endpoints:
  116. fullpath = _download(url, root_dir, md5sum, method=method)
  117. else:
  118. while not os.path.exists(fullpath):
  119. time.sleep(1)
  120. if ParallelEnv().current_endpoint in unique_endpoints:
  121. if decompress and (
  122. tarfile.is_tarfile(fullpath) or zipfile.is_zipfile(fullpath)
  123. ):
  124. fullpath = _decompress(fullpath)
  125. return fullpath
  126. def _get_download(url, fullname):
  127. # using requests.get method
  128. fname = osp.basename(fullname)
  129. try:
  130. with httpx.stream(
  131. "GET", url, timeout=None, follow_redirects=True
  132. ) as req:
  133. if req.status_code != 200:
  134. raise RuntimeError(
  135. f"Downloading from {url} failed with code "
  136. f"{req.status_code}!"
  137. )
  138. tmp_fullname = fullname + "_tmp"
  139. total_size = req.headers.get('content-length')
  140. with open(tmp_fullname, 'wb') as f:
  141. if total_size:
  142. with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
  143. for chunk in req.iter_bytes(chunk_size=1024):
  144. f.write(chunk)
  145. pbar.update(1)
  146. else:
  147. for chunk in req.iter_bytes(chunk_size=1024):
  148. if chunk:
  149. f.write(chunk)
  150. shutil.move(tmp_fullname, fullname)
  151. return fullname
  152. except Exception as e: # requests.exceptions.ConnectionError
  153. logger.info(
  154. f"Downloading {fname} from {url} failed with exception {str(e)}"
  155. )
  156. return False
# Registry of download back-ends: maps the `method` argument accepted by
# `_download` to the function that performs the transfer. Only 'get' is
# registered here.
_download_methods = {'get': _get_download}
  158. def _download(url, path, md5sum=None, method='get'):
  159. """
  160. Download from url, save to path.
  161. url (str): download url
  162. path (str): download to given path
  163. md5sum (str): md5 sum of download package
  164. method (str): which download method to use. Support `wget` and `get`. Default is `get`.
  165. """
  166. assert method in _download_methods, f'make sure `{method}` implemented'
  167. if not osp.exists(path):
  168. os.makedirs(path)
  169. fname = osp.split(url)[-1]
  170. fullname = osp.join(path, fname)
  171. retry_cnt = 0
  172. logger.info(f"Downloading {fname} from {url}")
  173. while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
  174. logger.info(f"md5check {fullname} and {md5sum}")
  175. if retry_cnt < DOWNLOAD_RETRY_LIMIT:
  176. retry_cnt += 1
  177. else:
  178. raise RuntimeError(
  179. f"Download from {url} failed. " "Retry limit reached"
  180. )
  181. if not _download_methods[method](url, fullname):
  182. time.sleep(1)
  183. continue
  184. return fullname
  185. def _md5check(fullname, md5sum=None):
  186. if md5sum is None:
  187. return True
  188. logger.info(f"File {fullname} md5 checking...")
  189. md5 = hashlib.md5()
  190. with open(fullname, 'rb') as f:
  191. for chunk in iter(lambda: f.read(4096), b""):
  192. md5.update(chunk)
  193. calc_md5sum = md5.hexdigest()
  194. if calc_md5sum != md5sum:
  195. logger.info(
  196. f"File {fullname} md5 check failed, {calc_md5sum}(calc) != "
  197. f"{md5sum}(base)"
  198. )
  199. return False
  200. return True
  201. def _decompress(fname):
  202. """
  203. Decompress for zip and tar file
  204. """
  205. logger.info(f"Decompressing {fname}...")
  206. # For protecting decompressing interrupted,
  207. # decompress to fpath_tmp directory firstly, if decompress
  208. # successed, move decompress files to fpath and delete
  209. # fpath_tmp and remove download compress file.
  210. if tarfile.is_tarfile(fname):
  211. uncompressed_path = _uncompress_file_tar(fname)
  212. elif zipfile.is_zipfile(fname):
  213. uncompressed_path = _uncompress_file_zip(fname)
  214. else:
  215. raise TypeError(f"Unsupport compress file type {fname}")
  216. return uncompressed_path
  217. def _uncompress_file_zip(filepath):
  218. with zipfile.ZipFile(filepath, 'r') as files:
  219. file_list_tmp = files.namelist()
  220. file_list = []
  221. for file in file_list_tmp:
  222. file_list.append(file.replace("../", ""))
  223. file_dir = os.path.dirname(filepath)
  224. if _is_a_single_file(file_list):
  225. rootpath = file_list[0]
  226. uncompressed_path = os.path.join(file_dir, rootpath)
  227. files.extractall(file_dir)
  228. elif _is_a_single_dir(file_list):
  229. # `strip(os.sep)` to remove `os.sep` in the tail of path
  230. rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
  231. os.sep
  232. )[-1]
  233. uncompressed_path = os.path.join(file_dir, rootpath)
  234. files.extractall(file_dir)
  235. else:
  236. rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
  237. uncompressed_path = os.path.join(file_dir, rootpath)
  238. if not os.path.exists(uncompressed_path):
  239. os.makedirs(uncompressed_path)
  240. files.extractall(os.path.join(file_dir, rootpath))
  241. return uncompressed_path
  242. def _uncompress_file_tar(filepath, mode="r:*"):
  243. with tarfile.open(filepath, mode) as files:
  244. file_list_tmp = files.getnames()
  245. file_list = []
  246. for file in file_list_tmp:
  247. assert (
  248. file[0] != "/"
  249. ), f"uncompress file path {file} should not start with /"
  250. file_list.append(file.replace("../", ""))
  251. file_dir = os.path.dirname(filepath)
  252. if _is_a_single_file(file_list):
  253. rootpath = file_list[0]
  254. uncompressed_path = os.path.join(file_dir, rootpath)
  255. files.extractall(file_dir)
  256. elif _is_a_single_dir(file_list):
  257. rootpath = os.path.splitext(file_list[0].strip(os.sep))[0].split(
  258. os.sep
  259. )[-1]
  260. uncompressed_path = os.path.join(file_dir, rootpath)
  261. files.extractall(file_dir)
  262. else:
  263. rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1]
  264. uncompressed_path = os.path.join(file_dir, rootpath)
  265. if not os.path.exists(uncompressed_path):
  266. os.makedirs(uncompressed_path)
  267. files.extractall(os.path.join(file_dir, rootpath))
  268. return uncompressed_path
  269. def _is_a_single_file(file_list):
  270. if len(file_list) == 1 and file_list[0].find(os.sep) < 0:
  271. return True
  272. return False
  273. def _is_a_single_dir(file_list):
  274. new_file_list = []
  275. for file_path in file_list:
  276. if '/' in file_path:
  277. file_path = file_path.replace('/', os.sep)
  278. elif '\\' in file_path:
  279. file_path = file_path.replace('\\', os.sep)
  280. new_file_list.append(file_path)
  281. file_name = new_file_list[0].split(os.sep)[0]
  282. for i in range(1, len(new_file_list)):
  283. if file_name != new_file_list[i].split(os.sep)[0]:
  284. return False
  285. return True