# caching.py 13 KB
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import hashlib
  3. import os
  4. import pickle
  5. import tempfile
  6. import threading
  7. from shutil import move, rmtree
  8. from typing import Dict
  9. from modelscope.hub.constants import ( # noqa
  10. FILE_HASH, MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION)
  11. from modelscope.hub.utils.utils import compute_hash
  12. from modelscope.utils.logger import get_logger
  13. logger = get_logger()
  14. enable_default_hash_validation = \
  15. os.getenv(MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION, 'False').strip().lower() == 'true'
  16. """Implements caching functionality, used internally only
  17. """
  18. class FileSystemCache(object):
  19. KEY_FILE_NAME = '.msc'
  20. MODEL_META_FILE_NAME = '.mdl'
  21. MODEL_META_MODEL_ID = 'id'
  22. MODEL_VERSION_FILE_NAME = '.mv'
  23. """Local file cache.
  24. """
  25. def __init__(
  26. self,
  27. cache_root_location: str,
  28. **kwargs,
  29. ):
  30. """Base file system cache interface.
  31. Args:
  32. cache_root_location (str): The root location to store files.
  33. kwargs(dict): The keyword arguments.
  34. """
  35. self._cache_lock = threading.RLock()
  36. os.makedirs(cache_root_location, exist_ok=True)
  37. self.cache_root_location = cache_root_location
  38. self.load_cache()
  39. def get_root_location(self):
  40. return self.cache_root_location
  41. def load_cache(self):
  42. self.cached_files = []
  43. cache_keys_file_path = os.path.join(self.cache_root_location,
  44. FileSystemCache.KEY_FILE_NAME)
  45. if os.path.exists(cache_keys_file_path):
  46. with open(cache_keys_file_path, 'rb') as f:
  47. self.cached_files = pickle.load(f)
  48. def save_cached_files(self):
  49. """
  50. Save cache metadata in order to verify that the cached content is consistent with the remote content.
  51. Example of the cached content:
  52. [{'Path': 'configuration.json', 'Revision': 'f01dxxx'}, {'Path': 'model.bin', 'Revision': '1159xxx'}, ...]
  53. """
  54. with self._cache_lock:
  55. cache_keys_file_path = os.path.join(self.cache_root_location,
  56. FileSystemCache.KEY_FILE_NAME)
  57. fd, temp_filename = tempfile.mkstemp(
  58. suffix='.tmp', dir=self.cache_root_location)
  59. try:
  60. with os.fdopen(fd, 'wb') as f:
  61. pickle.dump(self.cached_files, f)
  62. move(temp_filename, cache_keys_file_path)
  63. except Exception:
  64. try:
  65. os.close(fd)
  66. except OSError:
  67. pass
  68. if os.path.exists(temp_filename):
  69. os.unlink(temp_filename)
  70. raise
  71. def get_file(self, key):
  72. """Check the key is in the cache, if exist, return the file, otherwise return None.
  73. Args:
  74. key(str): The cache key.
  75. Raises:
  76. None
  77. """
  78. pass
  79. def put_file(self, key, location):
  80. """Put file to the cache.
  81. Args:
  82. key (str): The cache key
  83. location (str): Location of the file, we will move the file to cache.
  84. Raises:
  85. None
  86. """
  87. pass
  88. def remove_key(self, key):
  89. """Remove cache key in index, The file is removed manually
  90. Args:
  91. key (dict): The cache key.
  92. """
  93. if key in self.cached_files:
  94. self.cached_files.remove(key)
  95. self.save_cached_files()
  96. def exists(self, key):
  97. for cache_file in self.cached_files:
  98. if cache_file == key:
  99. return True
  100. return False
  101. def clear_cache(self):
  102. """Remove all files and metadata from the cache
  103. In the case of multiple cache locations, this clears only the last one,
  104. which is assumed to be the read/write one.
  105. """
  106. rmtree(self.cache_root_location)
  107. self.load_cache()
  108. def hash_name(self, key):
  109. return hashlib.sha256(key.encode()).hexdigest()
  110. class ModelFileSystemCache(FileSystemCache):
  111. """Local cache file layout
  112. cache_root/owner/model_name/individual cached files and cache index file '.mcs'
  113. Save only one version for each file.
  114. """
  115. def __init__(self, cache_root, owner=None, name=None):
  116. """Put file to the cache
  117. Args:
  118. cache_root(`str`): The modelscope local cache root(default: ~/.cache/modelscope/hub)
  119. owner(`str`): The model owner.
  120. name('str'): The name of the model
  121. Returns:
  122. Raises:
  123. None
  124. <Tip>
  125. model_id = {owner}/{name}
  126. </Tip>
  127. """
  128. if owner is None or name is None:
  129. # get model meta from
  130. super().__init__(os.path.join(cache_root))
  131. self.load_model_meta()
  132. else:
  133. super().__init__(os.path.join(cache_root, owner, name))
  134. self.model_meta = {
  135. FileSystemCache.MODEL_META_MODEL_ID: '%s/%s' % (owner, name)
  136. }
  137. self.save_model_meta()
  138. self.cached_model_revision = self.load_model_version()
  139. def load_model_meta(self):
  140. meta_file_path = os.path.join(self.cache_root_location,
  141. FileSystemCache.MODEL_META_FILE_NAME)
  142. if os.path.exists(meta_file_path):
  143. with open(meta_file_path, 'rb') as f:
  144. self.model_meta = pickle.load(f)
  145. else:
  146. self.model_meta = {FileSystemCache.MODEL_META_MODEL_ID: 'unknown'}
  147. def load_model_version(self):
  148. model_version_file_path = os.path.join(
  149. self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
  150. if os.path.exists(model_version_file_path):
  151. with open(model_version_file_path, 'r') as f:
  152. return f.read().strip()
  153. else:
  154. return None
  155. def save_model_version(self, revision_info: Dict):
  156. model_version_file_path = os.path.join(
  157. self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
  158. with open(model_version_file_path, 'w') as f:
  159. if isinstance(revision_info, dict):
  160. version_info_str = 'Revision:%s,CreatedAt:%s' % (
  161. revision_info['Revision'], revision_info['CreatedAt'])
  162. f.write(version_info_str)
  163. else:
  164. f.write(revision_info)
  165. def get_model_id(self):
  166. return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID]
  167. def save_model_meta(self):
  168. meta_file_path = os.path.join(self.cache_root_location,
  169. FileSystemCache.MODEL_META_FILE_NAME)
  170. with open(meta_file_path, 'wb') as f:
  171. pickle.dump(self.model_meta, f)
  172. def get_file_by_path(self, file_path):
  173. """Retrieve the cache if there is file match the path.
  174. Args:
  175. file_path (str): The file path in the model.
  176. Returns:
  177. path: the full path of the file.
  178. """
  179. for cached_file in self.cached_files:
  180. if file_path == cached_file['Path']:
  181. cached_file_path = os.path.join(self.cache_root_location,
  182. cached_file['Path'])
  183. if os.path.exists(cached_file_path):
  184. return cached_file_path
  185. else:
  186. self.remove_key(cached_file)
  187. return None
  188. def get_file_by_path_and_commit_id(self, file_path, commit_id):
  189. """Retrieve the cache if there is file match the path.
  190. Args:
  191. file_path (str): The file path in the model.
  192. commit_id (str): The commit id of the file
  193. Returns:
  194. path: the full path of the file.
  195. """
  196. for cached_file in self.cached_files:
  197. if file_path == cached_file['Path'] and \
  198. (cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])):
  199. cached_file_path = os.path.join(self.cache_root_location,
  200. cached_file['Path'])
  201. if os.path.exists(cached_file_path):
  202. return cached_file_path
  203. else:
  204. self.remove_key(cached_file)
  205. return None
  206. def get_file_by_info(self, model_file_info):
  207. """Check if exist cache file.
  208. Args:
  209. model_file_info (ModelFileInfo): The file information of the file.
  210. Returns:
  211. str: The file path.
  212. """
  213. cache_key = self.__get_cache_key(model_file_info)
  214. for cached_file in self.cached_files:
  215. if cached_file == cache_key:
  216. orig_path = os.path.join(self.cache_root_location,
  217. cached_file['Path'])
  218. if os.path.exists(orig_path):
  219. return orig_path
  220. else:
  221. self.remove_key(cached_file)
  222. break
  223. return None
  224. def __get_cache_key(self, model_file_info):
  225. cache_key = {
  226. 'Path': model_file_info['Path'],
  227. 'Revision': model_file_info['Revision'], # commit id
  228. }
  229. return cache_key
  230. def exists(self, model_file_info):
  231. """Check the file is cached or not. Note existence check will also cover digest check
  232. Args:
  233. model_file_info (CachedFileInfo): The cached file info
  234. Returns:
  235. bool: If exists and has the same hash, return True otherwise False
  236. """
  237. key = self.__get_cache_key(model_file_info)
  238. is_exists = False
  239. file_path = key['Path']
  240. cache_file_path = os.path.join(self.cache_root_location,
  241. model_file_info['Path'])
  242. for cached_key in self.cached_files:
  243. if cached_key['Path'] == file_path and (
  244. cached_key['Revision'].startswith(key['Revision'])
  245. or key['Revision'].startswith(cached_key['Revision'])):
  246. expected_hash = model_file_info[FILE_HASH]
  247. if expected_hash is not None and os.path.exists(
  248. cache_file_path):
  249. # compute hash only when enabled, otherwise just meet expectation by default
  250. if enable_default_hash_validation:
  251. cache_file_sha256 = compute_hash(cache_file_path)
  252. else:
  253. cache_file_sha256 = expected_hash
  254. if expected_hash == cache_file_sha256:
  255. is_exists = True
  256. break
  257. else:
  258. logger.info(
  259. f'File [{file_path}] exists in cache but with a mismatched hash, will re-download.'
  260. )
  261. if is_exists:
  262. if os.path.exists(cache_file_path):
  263. return True
  264. else:
  265. self.remove_key(
  266. model_file_info) # someone may manual delete the file
  267. return False
  268. def remove_if_exists(self, model_file_info):
  269. """We in cache, remove it.
  270. Args:
  271. model_file_info (ModelFileInfo): The model file information from server.
  272. """
  273. for cached_file in self.cached_files:
  274. if cached_file['Path'] == model_file_info['Path']:
  275. self.remove_key(cached_file)
  276. file_path = os.path.join(self.cache_root_location,
  277. cached_file['Path'])
  278. if os.path.exists(file_path):
  279. os.remove(file_path)
  280. break
  281. def put_file(self, model_file_info, model_file_location):
  282. """Put model on model_file_location to cache, the model first download to /tmp, and move to cache.
  283. Args:
  284. model_file_info (str): The file description returned by get_model_files.
  285. model_file_location (str): The location of the temporary file.
  286. Returns:
  287. str: The location of the cached file.
  288. """
  289. self.remove_if_exists(model_file_info) # backup old revision
  290. cache_key = self.__get_cache_key(model_file_info)
  291. cache_full_path = os.path.join(
  292. self.cache_root_location,
  293. cache_key['Path']) # Branch and Tag do not have same name.
  294. cache_file_dir = os.path.dirname(cache_full_path)
  295. if not os.path.exists(cache_file_dir):
  296. os.makedirs(cache_file_dir, exist_ok=True)
  297. # We can't make operation transaction
  298. move(model_file_location, cache_full_path)
  299. self.cached_files.append(cache_key)
  300. self.save_cached_files()
  301. return cache_full_path