caching.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. """
  2. caching tools
  3. """
  4. import hashlib
  5. import os
  6. import pickle
  7. import tempfile
  8. from shutil import move, rmtree
  9. from typing import Dict
  10. from aistudio_sdk.config import ( # noqa
  11. FILE_HASH, AISTUDIO_ENABLE_DEFAULT_HASH_VALIDATION, CACHE_KEY)
  12. from .util import compute_hash
  13. from aistudio_sdk import log
  14. enable_default_hash_validation = \
  15. os.getenv(AISTUDIO_ENABLE_DEFAULT_HASH_VALIDATION, 'False').strip().lower() == 'true'
  16. """Implements caching functionality, used internally only
  17. """
  18. class FileSystemCache(object):
  19. """info"""
  20. KEY_FILE_NAME = '.msc'
  21. MODEL_META_FILE_NAME = '.mdl'
  22. MODEL_META_MODEL_ID = 'id'
  23. MODEL_VERSION_FILE_NAME = '.mv'
  24. """Local file cache.
  25. """
  26. def __init__(
  27. self,
  28. cache_root_location: str,
  29. ):
  30. """Base file system cache interface.
  31. Args:
  32. cache_root_location (str): The root location to store files.
  33. kwargs(dict): The keyword arguments.
  34. """
  35. os.makedirs(cache_root_location, exist_ok=True)
  36. self.cache_root_location = cache_root_location
  37. self.load_cache()
  38. def get_root_location(self):
  39. """location"""
  40. return self.cache_root_location
  41. def load_cache(self):
  42. """load"""
  43. self.cached_files = []
  44. cache_keys_file_path = os.path.join(self.cache_root_location,
  45. FileSystemCache.KEY_FILE_NAME)
  46. if os.path.exists(cache_keys_file_path):
  47. with open(cache_keys_file_path, 'rb') as f:
  48. self.cached_files = pickle.load(f)
  49. def save_cached_files(self):
  50. """Save cache metadata."""
  51. # save new meta to tmp and move to KEY_FILE_NAME
  52. cache_keys_file_path = os.path.join(self.cache_root_location,
  53. FileSystemCache.KEY_FILE_NAME)
  54. fd, fn = tempfile.mkstemp()
  55. with open(fd, 'wb') as f:
  56. pickle.dump(self.cached_files, f)
  57. move(fn, cache_keys_file_path)
  58. def get_file(self, key):
  59. """Check the key is in the cache, if exists, return the file, otherwise return None.
  60. Args:
  61. key(str): The cache key.
  62. Raises:
  63. None
  64. """
  65. pass
  66. def put_file(self, key, location):
  67. """Put file to the cache.
  68. Args:
  69. key (str): The cache key
  70. location (str): Location of the file, we will move the file to cache.
  71. Raises:
  72. None
  73. """
  74. pass
  75. def remove_key(self, key):
  76. """Remove cache key in index, The file is removed manually
  77. Args:
  78. key (dict): The cache key.
  79. """
  80. if key in self.cached_files:
  81. self.cached_files.remove(key)
  82. self.save_cached_files()
  83. def exists(self, key):
  84. """check"""
  85. for cache_file in self.cached_files:
  86. if cache_file == key:
  87. return True
  88. return False
  89. def clear_cache(self):
  90. """Remove all files and metadata from the cache
  91. In the case of multiple cache locations, this clears only the last one,
  92. which is assumed to be the read/write one.
  93. """
  94. rmtree(self.cache_root_location)
  95. self.load_cache()
  96. def hash_name(self, key):
  97. """name"""
  98. return hashlib.sha256(key.encode()).hexdigest()
  99. class ModelFileSystemCache(FileSystemCache):
  100. """Local cache file layout
  101. cache_root/owner/model_name/individual cached files and cache index file '.mcs'
  102. Save only one version for each file.
  103. """
  104. def __init__(self, cache_root, owner=None, name=None):
  105. """Put file to the cache
  106. Args:
  107. cache_root(`str`): The aistudio local cache root(default: ~/.cache/aistudio/)
  108. owner(`str`): The model owner.
  109. name('str'): The name of the model
  110. Returns:
  111. Raises:
  112. None
  113. <Tip>
  114. model_id = {owner}/{name}
  115. </Tip>
  116. """
  117. if owner is None or name is None:
  118. # get model meta from
  119. super().__init__(os.path.join(cache_root))
  120. self.load_model_meta()
  121. else:
  122. super().__init__(os.path.join(cache_root, owner, name))
  123. self.model_meta = {
  124. FileSystemCache.MODEL_META_MODEL_ID: '%s/%s' % (owner, name)
  125. }
  126. self.save_model_meta()
  127. self.cached_model_revision = self.load_model_version()
  128. def load_model_meta(self):
  129. """get meta"""
  130. meta_file_path = os.path.join(self.cache_root_location,
  131. FileSystemCache.MODEL_META_FILE_NAME)
  132. if os.path.exists(meta_file_path):
  133. with open(meta_file_path, 'rb') as f:
  134. self.model_meta = pickle.load(f)
  135. else:
  136. self.model_meta = {FileSystemCache.MODEL_META_MODEL_ID: 'unknown'}
  137. def load_model_version(self):
  138. """use version info"""
  139. model_version_file_path = os.path.join(
  140. self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
  141. if os.path.exists(model_version_file_path):
  142. with open(model_version_file_path, 'r') as f:
  143. return f.read().strip()
  144. else:
  145. return None
  146. def save_model_version(self, revision_info: Dict):
  147. """save file info"""
  148. model_version_file_path = os.path.join(
  149. self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
  150. with open(model_version_file_path, 'w') as f:
  151. if isinstance(revision_info, dict):
  152. version_info_str = 'Revision:%s,CreatedAt:%s' % (
  153. revision_info['Revision'], revision_info.get('CreatedAt') or 'unknown')
  154. f.write(version_info_str)
  155. else:
  156. f.write(revision_info)
  157. def get_model_id(self):
  158. """get"""
  159. return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID]
  160. def save_model_meta(self):
  161. """save"""
  162. meta_file_path = os.path.join(self.cache_root_location,
  163. FileSystemCache.MODEL_META_FILE_NAME)
  164. with open(meta_file_path, 'wb') as f:
  165. pickle.dump(self.model_meta, f)
  166. def get_file_by_path(self, file_path):
  167. """Retrieve the cache if there is file match the path.
  168. Args:
  169. file_path (str): The file path in the model.
  170. Returns:
  171. path: the full path of the file.
  172. """
  173. for cached_file in self.cached_files:
  174. if file_path == cached_file['Path']:
  175. cached_file_path = os.path.join(self.cache_root_location,
  176. cached_file['Path'])
  177. if os.path.exists(cached_file_path):
  178. return cached_file_path
  179. else:
  180. self.remove_key(cached_file)
  181. return None
  182. def get_file_by_path_and_commit_id(self, file_path, commit_id):
  183. """Retrieve the cache if there is file match the path.
  184. Args:
  185. file_path (str): The file path in the model.
  186. commit_id (str): The commit id of the file
  187. Returns:
  188. path: the full path of the file.
  189. """
  190. for cached_file in self.cached_files:
  191. if file_path == cached_file['Path'] and \
  192. (cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])):
  193. cached_file_path = os.path.join(self.cache_root_location,
  194. cached_file['Path'])
  195. if os.path.exists(cached_file_path):
  196. return cached_file_path
  197. else:
  198. self.remove_key(cached_file)
  199. return None
  200. def get_file_by_info(self, model_file_info):
  201. """Check if exist cache file.
  202. Args:
  203. model_file_info (ModelFileInfo): The file information of the file.
  204. Returns:
  205. str: The file path.
  206. """
  207. cache_key = self.__get_cache_key(model_file_info)
  208. for cached_file in self.cached_files:
  209. if cached_file == cache_key:
  210. orig_path = os.path.join(self.cache_root_location,
  211. cached_file['Path'])
  212. if os.path.exists(orig_path):
  213. return orig_path
  214. else:
  215. self.remove_key(cached_file)
  216. break
  217. return None
  218. def __get_cache_key(self, model_file_info):
  219. cache_key = {
  220. 'Path': model_file_info['path'],
  221. 'Revision': model_file_info['sha'], # commit id
  222. }
  223. return cache_key
  224. def exists(self, model_file_info):
  225. """Check the file is cached or not. Note existence check will also cover digest check
  226. Args:
  227. model_file_info (CachedFileInfo): The cached file info
  228. Returns:
  229. bool: If exists and has the same hash, return True otherwise False
  230. """
  231. key = self.__get_cache_key(model_file_info)
  232. is_exists = False
  233. file_path = key['Path']
  234. cache_file_path = os.path.join(self.cache_root_location,
  235. model_file_info['path'])
  236. for cached_key in self.cached_files:
  237. if cached_key['Path'] == file_path and (
  238. cached_key['Revision'].startswith(key['Revision'])
  239. or key['Revision'].startswith(cached_key['Revision'])):
  240. expected_hash = model_file_info[CACHE_KEY]
  241. if expected_hash is not None and os.path.exists(
  242. cache_file_path):
  243. # compute hash only when enabled, otherwise just meet expectation by default
  244. if enable_default_hash_validation:
  245. cache_file_sha256 = compute_hash(cache_file_path)
  246. else:
  247. cache_file_sha256 = expected_hash
  248. if expected_hash == cache_file_sha256:
  249. is_exists = True
  250. break
  251. else:
  252. log.info(
  253. f'File [{file_path}] exists in cache but with a mismatched hash, will re-download.'
  254. )
  255. if is_exists:
  256. if os.path.exists(cache_file_path):
  257. return True
  258. else:
  259. self.remove_key(
  260. model_file_info)
  261. return False
  262. def remove_if_exists(self, model_file_info):
  263. """We in cache, remove it.
  264. Args:
  265. model_file_info (ModelFileInfo): The model file information from server.
  266. """
  267. for cached_file in self.cached_files:
  268. if cached_file['Path'] == model_file_info['path']:
  269. self.remove_key(cached_file)
  270. file_path = os.path.join(self.cache_root_location,
  271. cached_file['Path'])
  272. if os.path.exists(file_path):
  273. os.remove(file_path)
  274. break
  275. def put_file(self, model_file_info, model_file_location):
  276. """Put model on model_file_location to cache, the model first download to /tmp, and move to cache.
  277. Args:
  278. model_file_info (str): The file description returned by get_model_files.
  279. model_file_location (str): The location of the temporary file.
  280. Returns:
  281. str: The location of the cached file.
  282. """
  283. self.remove_if_exists(model_file_info) # backup old revision
  284. cache_key = self.__get_cache_key(model_file_info)
  285. cache_full_path = os.path.join(
  286. self.cache_root_location,
  287. cache_key['Path']) # Branch and Tag do not have same name.
  288. cache_file_dir = os.path.dirname(cache_full_path)
  289. if not os.path.exists(cache_file_dir):
  290. os.makedirs(cache_file_dir, exist_ok=True)
  291. # We can't make operation transaction
  292. move(model_file_location, cache_full_path)
  293. self.cached_files.append(cache_key)
  294. self.save_cached_files()
  295. return cache_full_path