common.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import errno
  15. import glob
  16. import hashlib
  17. import importlib
  18. import os
  19. import pickle
  20. import re
  21. import shutil
  22. import sys
  23. import tempfile
  24. import httpx
  25. import paddle
  26. import paddle.dataset
  27. __all__ = []
  28. HOME = os.path.expanduser('~')
  29. # If the default HOME dir does not support writing, we
  30. # will create a temporary folder to store the cache files.
  31. if not os.access(HOME, os.W_OK):
  32. r"""
  33. gettempdir() return the name of the directory used for temporary files.
  34. On Windows, the directories C:\TEMP, C:\TMP, \TEMP, and \TMP, in that order.
  35. On all other platforms, the directories /tmp, /var/tmp, and /usr/tmp, in that order.
  36. For more details, please refer to https://docs.python.org/3/library/tempfile.html
  37. """
  38. HOME = tempfile.gettempdir()
  39. DATA_HOME = os.path.join(HOME, '.cache', 'paddle', 'dataset')
  40. # When running unit tests, there could be multiple processes that
  41. # trying to create DATA_HOME directory simultaneously, so we cannot
  42. # use a if condition to check for the existence of the directory;
  43. # instead, we use the filesystem as the synchronization mechanism by
  44. # catching returned errors.
  45. def must_mkdirs(path):
  46. try:
  47. os.makedirs(DATA_HOME)
  48. except OSError as exc:
  49. if exc.errno != errno.EEXIST:
  50. raise
  51. must_mkdirs(DATA_HOME)
  52. def md5file(fname):
  53. hash_md5 = hashlib.md5()
  54. f = open(fname, "rb")
  55. for chunk in iter(lambda: f.read(4096), b""):
  56. hash_md5.update(chunk)
  57. f.close()
  58. return hash_md5.hexdigest()
  59. def download(url, module_name, md5sum, save_name=None):
  60. module_name = re.match("^[a-zA-Z0-9_/\\-]+$", module_name).group()
  61. if isinstance(save_name, str):
  62. save_name = re.match(
  63. "^(?:(?!\\.\\.)[a-zA-Z0-9_/\\.-])+$", save_name
  64. ).group()
  65. dirname = os.path.join(DATA_HOME, module_name)
  66. if not os.path.exists(dirname):
  67. os.makedirs(dirname)
  68. filename = os.path.join(
  69. dirname, url.split('/')[-1] if save_name is None else save_name
  70. )
  71. if os.path.exists(filename) and md5file(filename) == md5sum:
  72. return filename
  73. retry = 0
  74. retry_limit = 3
  75. while not (os.path.exists(filename) and md5file(filename) == md5sum):
  76. if os.path.exists(filename):
  77. sys.stderr.write(f"file {md5file(filename)} md5 {md5sum}\n")
  78. if retry < retry_limit:
  79. retry += 1
  80. else:
  81. raise RuntimeError(
  82. f"Cannot download {url} within retry limit {retry_limit}"
  83. )
  84. sys.stderr.write(
  85. f"Cache file {filename} not found, downloading {url} \n"
  86. )
  87. sys.stderr.write("Begin to download\n")
  88. try:
  89. # (risemeup1):use httpx to replace requests
  90. with httpx.stream(
  91. "GET", url, timeout=None, follow_redirects=True
  92. ) as r:
  93. total_length = r.headers.get('content-length')
  94. if total_length is None:
  95. with open(filename, 'wb') as f:
  96. shutil.copyfileobj(r.raw, f)
  97. else:
  98. with open(filename, 'wb') as f:
  99. chunk_size = 4096
  100. total_length = int(total_length)
  101. total_iter = total_length / chunk_size + 1
  102. log_interval = (
  103. total_iter // 20 if total_iter > 20 else 1
  104. )
  105. log_index = 0
  106. bar = paddle.hapi.progressbar.ProgressBar(
  107. total_iter, name='item'
  108. )
  109. for data in r.iter_bytes(chunk_size=chunk_size):
  110. f.write(data)
  111. log_index += 1
  112. bar.update(log_index, {})
  113. if log_index % log_interval == 0:
  114. bar.update(log_index)
  115. except Exception as e:
  116. # re-try
  117. continue
  118. sys.stderr.write("\nDownload finished\n")
  119. sys.stdout.flush()
  120. return filename
  121. def fetch_all():
  122. for module_name in [
  123. x for x in dir(paddle.dataset) if not x.startswith("__")
  124. ]:
  125. if "fetch" in dir(
  126. importlib.import_module("paddle.dataset.%s" % module_name)
  127. ):
  128. importlib.import_module('paddle.dataset.%s' % module_name).fetch()
  129. def split(reader, line_count, suffix="%05d.pickle", dumper=pickle.dump):
  130. """
  131. you can call the function as:
  132. split(paddle.dataset.cifar.train10(), line_count=1000,
  133. suffix="imikolov-train-%05d.pickle")
  134. the output files as:
  135. |-imikolov-train-00000.pickle
  136. |-imikolov-train-00001.pickle
  137. |- ...
  138. |-imikolov-train-00480.pickle
  139. :param reader: is a reader creator
  140. :param line_count: line count for each file
  141. :param suffix: the suffix for the output files, should contain "%d"
  142. means the id for each file. Default is "%05d.pickle"
  143. :param dumper: is a callable function that dump object to file, this
  144. function will be called as dumper(obj, f) and obj is the object
  145. will be dumped, f is a file object. Default is cPickle.dump.
  146. """
  147. if not callable(dumper):
  148. raise TypeError("dumper should be callable.")
  149. lines = []
  150. indx_f = 0
  151. for i, d in enumerate(reader()):
  152. lines.append(d)
  153. if i >= line_count and i % line_count == 0:
  154. with open(suffix % indx_f, "w") as f:
  155. dumper(lines, f)
  156. lines = []
  157. indx_f += 1
  158. if lines:
  159. with open(suffix % indx_f, "w") as f:
  160. dumper(lines, f)
  161. def cluster_files_reader(
  162. files_pattern, trainer_count, trainer_id, loader=pickle.load
  163. ):
  164. """
  165. Create a reader that yield element from the given files, select
  166. a file set according trainer count and trainer_id
  167. :param files_pattern: the files which generating by split(...)
  168. :param trainer_count: total trainer count
  169. :param trainer_id: the trainer rank id
  170. :param loader: is a callable function that load object from file, this
  171. function will be called as loader(f) and f is a file object.
  172. Default is cPickle.load
  173. """
  174. def reader():
  175. if not callable(loader):
  176. raise TypeError("loader should be callable.")
  177. file_list = glob.glob(files_pattern)
  178. file_list.sort()
  179. my_file_list = []
  180. for idx, fn in enumerate(file_list):
  181. if idx % trainer_count == trainer_id:
  182. print("append file: %s" % fn)
  183. my_file_list.append(fn)
  184. for fn in my_file_list:
  185. with open(fn, "r") as f:
  186. lines = loader(f)
  187. yield from lines
  188. return reader
  189. def _check_exists_and_download(path, url, md5, module_name, download=True):
  190. if path and os.path.exists(path):
  191. return path
  192. if download:
  193. return paddle.dataset.common.download(url, module_name, md5)
  194. else:
  195. raise ValueError(f'{path} not exists and auto download disabled')