file_utils.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import hashlib
  3. import inspect
  4. import io
  5. import os
  6. from pathlib import Path
  7. from shutil import Error, copy2, copystat
  8. from typing import BinaryIO, Optional, Union
  9. from urllib.parse import urlparse
  10. # TODO: remove this api, unify to flattened args
  11. def func_receive_dict_inputs(func):
  12. """to decide if a func could receive dict inputs or not
  13. Args:
  14. func (class): the target function to be inspected
  15. Returns:
  16. bool: if func only has one arg ``input`` or ``inputs``, return True, else return False
  17. """
  18. full_args_spec = inspect.getfullargspec(func)
  19. varargs = full_args_spec.varargs
  20. varkw = full_args_spec.varkw
  21. if not (varargs is None and varkw is None):
  22. return False
  23. args = [] if not full_args_spec.args else full_args_spec.args
  24. args.pop(0) if (args and args[0] in ['self', 'cls']) else args
  25. if len(args) == 1 and args[0] in ['input', 'inputs']:
  26. return True
  27. return False
  28. def get_default_modelscope_cache_dir():
  29. """
  30. default base dir: '~/.cache/modelscope'
  31. """
  32. default_cache_dir = os.path.expanduser(Path.home().joinpath(
  33. '.cache', 'modelscope', 'hub'))
  34. return default_cache_dir
  35. def get_modelscope_cache_dir() -> str:
  36. """Get modelscope cache dir, default location or
  37. setting with MODELSCOPE_CACHE
  38. Returns:
  39. str: the modelscope cache root.
  40. """
  41. return os.path.expanduser(
  42. os.getenv('MODELSCOPE_CACHE', get_default_modelscope_cache_dir()))
  43. def get_model_cache_root() -> str:
  44. """Get model cache root path.
  45. Returns:
  46. str: the modelscope model cache root.
  47. """
  48. return os.path.join(get_modelscope_cache_dir(), 'models')
  49. def get_dataset_cache_root() -> str:
  50. """Get dataset raw file cache root path.
  51. if `MODELSCOPE_CACHE` is set, return `MODELSCOPE_CACHE/datasets`,
  52. else return `~/.cache/modelscope/hub/datasets`
  53. Returns:
  54. str: the modelscope dataset raw file cache root.
  55. """
  56. return os.path.join(get_modelscope_cache_dir(), 'datasets')
  57. def get_dataset_cache_dir(dataset_id: str) -> str:
  58. """Get the dataset_id's path.
  59. dataset_cache_root/dataset_id.
  60. Args:
  61. dataset_id (str): The dataset id.
  62. Returns:
  63. str: The dataset_id's cache root path.
  64. """
  65. dataset_root = get_dataset_cache_root()
  66. return dataset_root if dataset_id is None else os.path.join(
  67. dataset_root, dataset_id + '/')
  68. def get_model_cache_dir(model_id: str) -> str:
  69. """cache dir precedence:
  70. function parameter > environment > ~/.cache/modelscope/hub/model_id
  71. Args:
  72. model_id (str, optional): The model id.
  73. Returns:
  74. str: the model_id dir if model_id not None, otherwise cache root dir.
  75. """
  76. root_path = get_model_cache_root()
  77. return root_path if model_id is None else os.path.join(
  78. root_path, model_id + '/')
  79. def read_file(path):
  80. with open(path, 'r') as f:
  81. text = f.read()
  82. return text
  83. def copytree_py37(src,
  84. dst,
  85. symlinks=False,
  86. ignore=None,
  87. copy_function=copy2,
  88. ignore_dangling_symlinks=False,
  89. dirs_exist_ok=False):
  90. """copy from py37 shutil. add the parameter dirs_exist_ok."""
  91. names = os.listdir(src)
  92. if ignore is not None:
  93. ignored_names = ignore(src, names)
  94. else:
  95. ignored_names = set()
  96. os.makedirs(dst, exist_ok=dirs_exist_ok)
  97. errors = []
  98. for name in names:
  99. if name in ignored_names:
  100. continue
  101. srcname = os.path.join(src, name)
  102. dstname = os.path.join(dst, name)
  103. try:
  104. if os.path.islink(srcname):
  105. linkto = os.readlink(srcname)
  106. if symlinks:
  107. # We can't just leave it to `copy_function` because legacy
  108. # code with a custom `copy_function` may rely on copytree
  109. # doing the right thing.
  110. os.symlink(linkto, dstname)
  111. copystat(srcname, dstname, follow_symlinks=not symlinks)
  112. else:
  113. # ignore dangling symlink if the flag is on
  114. if not os.path.exists(linkto) and ignore_dangling_symlinks:
  115. continue
  116. # otherwise let the copy occurs. copy2 will raise an error
  117. if os.path.isdir(srcname):
  118. copytree_py37(
  119. srcname,
  120. dstname,
  121. symlinks,
  122. ignore,
  123. copy_function,
  124. dirs_exist_ok=dirs_exist_ok)
  125. else:
  126. copy_function(srcname, dstname)
  127. elif os.path.isdir(srcname):
  128. copytree_py37(
  129. srcname,
  130. dstname,
  131. symlinks,
  132. ignore,
  133. copy_function,
  134. dirs_exist_ok=dirs_exist_ok)
  135. else:
  136. # Will raise a SpecialFileError for unsupported file types
  137. copy_function(srcname, dstname)
  138. # catch the Error from the recursive copytree so that we can
  139. # continue with other files
  140. except Error as err:
  141. errors.extend(err.args[0])
  142. except OSError as why:
  143. errors.append((srcname, dstname, str(why)))
  144. try:
  145. copystat(src, dst)
  146. except OSError as why:
  147. # Copying file access times may fail on Windows
  148. if getattr(why, 'winerror', None) is None:
  149. errors.append((src, dst, str(why)))
  150. if errors:
  151. raise Error(errors)
  152. return dst
  153. def get_file_size(file_path_or_obj: Union[str, Path, bytes, BinaryIO]) -> int:
  154. if isinstance(file_path_or_obj, (str, Path)):
  155. file_path = Path(file_path_or_obj)
  156. return file_path.stat().st_size
  157. elif isinstance(file_path_or_obj, bytes):
  158. return len(file_path_or_obj)
  159. elif isinstance(file_path_or_obj, io.BufferedIOBase):
  160. current_position = file_path_or_obj.tell()
  161. file_path_or_obj.seek(0, os.SEEK_END)
  162. size = file_path_or_obj.tell()
  163. file_path_or_obj.seek(current_position)
  164. return size
  165. else:
  166. raise TypeError(
  167. 'Unsupported type: must be string, Path, bytes, or io.BufferedIOBase'
  168. )
  169. def get_file_hash(
  170. file_path_or_obj: Union[str, Path, bytes, BinaryIO],
  171. buffer_size_mb: Optional[int] = 1,
  172. tqdm_desc: Optional[str] = '[Calculating]',
  173. disable_tqdm: Optional[bool] = True,
  174. ) -> dict:
  175. from tqdm.auto import tqdm
  176. file_size = get_file_size(file_path_or_obj)
  177. if file_size > 1024 * 1024 * 1024: # 1GB
  178. disable_tqdm = False
  179. name = 'Large File'
  180. if isinstance(file_path_or_obj, (str, Path)):
  181. path = file_path_or_obj if isinstance(
  182. file_path_or_obj, Path) else Path(file_path_or_obj)
  183. name = path.name
  184. tqdm_desc = f'[Validating Hash for {name}]'
  185. buffer_size = buffer_size_mb * 1024 * 1024
  186. file_hash = hashlib.sha256()
  187. chunk_hash_list = []
  188. progress = tqdm(
  189. total=file_size,
  190. initial=0,
  191. unit_scale=True,
  192. dynamic_ncols=True,
  193. unit='B',
  194. desc=tqdm_desc,
  195. disable=disable_tqdm,
  196. )
  197. if isinstance(file_path_or_obj, (str, Path)):
  198. with open(file_path_or_obj, 'rb') as f:
  199. while byte_chunk := f.read(buffer_size):
  200. chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
  201. file_hash.update(byte_chunk)
  202. progress.update(len(byte_chunk))
  203. file_hash = file_hash.hexdigest()
  204. final_chunk_size = buffer_size
  205. elif isinstance(file_path_or_obj, bytes):
  206. file_hash.update(file_path_or_obj)
  207. file_hash = file_hash.hexdigest()
  208. chunk_hash_list.append(file_hash)
  209. final_chunk_size = len(file_path_or_obj)
  210. progress.update(final_chunk_size)
  211. elif isinstance(file_path_or_obj, io.BufferedIOBase):
  212. file_path_or_obj.seek(0, os.SEEK_SET)
  213. while byte_chunk := file_path_or_obj.read(buffer_size):
  214. chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
  215. file_hash.update(byte_chunk)
  216. progress.update(len(byte_chunk))
  217. file_hash = file_hash.hexdigest()
  218. final_chunk_size = buffer_size
  219. file_path_or_obj.seek(0, os.SEEK_SET)
  220. else:
  221. progress.close()
  222. raise ValueError(
  223. 'Input must be str, Path, bytes or a io.BufferedIOBase')
  224. progress.close()
  225. return {
  226. 'file_path_or_obj': file_path_or_obj,
  227. 'file_hash': file_hash,
  228. 'file_size': file_size,
  229. 'chunk_size': final_chunk_size,
  230. 'chunk_nums': len(chunk_hash_list),
  231. 'chunk_hash_list': chunk_hash_list,
  232. }
  233. def is_relative_path(url_or_filename: str) -> bool:
  234. """
  235. Check if a given string is a relative path.
  236. """
  237. return urlparse(
  238. url_or_filename).scheme == '' and not os.path.isabs(url_or_filename)