# Copyright (c) Alibaba, Inc. and its affiliates. import hashlib import os import pickle import tempfile import threading from shutil import move, rmtree from typing import Dict from modelscope.hub.constants import ( # noqa FILE_HASH, MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION) from modelscope.hub.utils.utils import compute_hash from modelscope.utils.logger import get_logger logger = get_logger() enable_default_hash_validation = \ os.getenv(MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION, 'False').strip().lower() == 'true' """Implements caching functionality, used internally only """ class FileSystemCache(object): KEY_FILE_NAME = '.msc' MODEL_META_FILE_NAME = '.mdl' MODEL_META_MODEL_ID = 'id' MODEL_VERSION_FILE_NAME = '.mv' """Local file cache. """ def __init__( self, cache_root_location: str, **kwargs, ): """Base file system cache interface. Args: cache_root_location (str): The root location to store files. kwargs(dict): The keyword arguments. """ self._cache_lock = threading.RLock() os.makedirs(cache_root_location, exist_ok=True) self.cache_root_location = cache_root_location self.load_cache() def get_root_location(self): return self.cache_root_location def load_cache(self): self.cached_files = [] cache_keys_file_path = os.path.join(self.cache_root_location, FileSystemCache.KEY_FILE_NAME) if os.path.exists(cache_keys_file_path): with open(cache_keys_file_path, 'rb') as f: self.cached_files = pickle.load(f) def save_cached_files(self): """ Save cache metadata in order to verify that the cached content is consistent with the remote content. Example of the cached content: [{'Path': 'configuration.json', 'Revision': 'f01dxxx'}, {'Path': 'model.bin', 'Revision': '1159xxx'}, ...] """ with self._cache_lock: cache_keys_file_path = os.path.join(self.cache_root_location, FileSystemCache.KEY_FILE_NAME) fd, temp_filename = tempfile.mkstemp( suffix='.tmp', dir=self.cache_root_location) try: with os.fdopen(fd, 'wb') as f: pickle.dump(self.cached_files, f) move(temp_filename, cache_keys_file_path) except Exception: try: os.close(fd) except OSError: pass if os.path.exists(temp_filename): os.unlink(temp_filename) raise def get_file(self, key): """Check the key is in the cache, if exist, return the file, otherwise return None. Args: key(str): The cache key. Raises: None """ pass def put_file(self, key, location): """Put file to the cache. Args: key (str): The cache key location (str): Location of the file, we will move the file to cache. Raises: None """ pass def remove_key(self, key): """Remove cache key in index, The file is removed manually Args: key (dict): The cache key. """ if key in self.cached_files: self.cached_files.remove(key) self.save_cached_files() def exists(self, key): for cache_file in self.cached_files: if cache_file == key: return True return False def clear_cache(self): """Remove all files and metadata from the cache In the case of multiple cache locations, this clears only the last one, which is assumed to be the read/write one. """ rmtree(self.cache_root_location) self.load_cache() def hash_name(self, key): return hashlib.sha256(key.encode()).hexdigest() class ModelFileSystemCache(FileSystemCache): """Local cache file layout cache_root/owner/model_name/individual cached files and cache index file '.mcs' Save only one version for each file. """ def __init__(self, cache_root, owner=None, name=None): """Put file to the cache Args: cache_root(`str`): The modelscope local cache root(default: ~/.cache/modelscope/hub) owner(`str`): The model owner. name('str'): The name of the model Returns: Raises: None model_id = {owner}/{name} """ if owner is None or name is None: # get model meta from super().__init__(os.path.join(cache_root)) self.load_model_meta() else: super().__init__(os.path.join(cache_root, owner, name)) self.model_meta = { FileSystemCache.MODEL_META_MODEL_ID: '%s/%s' % (owner, name) } self.save_model_meta() self.cached_model_revision = self.load_model_version() def load_model_meta(self): meta_file_path = os.path.join(self.cache_root_location, FileSystemCache.MODEL_META_FILE_NAME) if os.path.exists(meta_file_path): with open(meta_file_path, 'rb') as f: self.model_meta = pickle.load(f) else: self.model_meta = {FileSystemCache.MODEL_META_MODEL_ID: 'unknown'} def load_model_version(self): model_version_file_path = os.path.join( self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME) if os.path.exists(model_version_file_path): with open(model_version_file_path, 'r') as f: return f.read().strip() else: return None def save_model_version(self, revision_info: Dict): model_version_file_path = os.path.join( self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME) with open(model_version_file_path, 'w') as f: if isinstance(revision_info, dict): version_info_str = 'Revision:%s,CreatedAt:%s' % ( revision_info['Revision'], revision_info['CreatedAt']) f.write(version_info_str) else: f.write(revision_info) def get_model_id(self): return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID] def save_model_meta(self): meta_file_path = os.path.join(self.cache_root_location, FileSystemCache.MODEL_META_FILE_NAME) with open(meta_file_path, 'wb') as f: pickle.dump(self.model_meta, f) def get_file_by_path(self, file_path): """Retrieve the cache if there is file match the path. Args: file_path (str): The file path in the model. Returns: path: the full path of the file. """ for cached_file in self.cached_files: if file_path == cached_file['Path']: cached_file_path = os.path.join(self.cache_root_location, cached_file['Path']) if os.path.exists(cached_file_path): return cached_file_path else: self.remove_key(cached_file) return None def get_file_by_path_and_commit_id(self, file_path, commit_id): """Retrieve the cache if there is file match the path. Args: file_path (str): The file path in the model. commit_id (str): The commit id of the file Returns: path: the full path of the file. """ for cached_file in self.cached_files: if file_path == cached_file['Path'] and \ (cached_file['Revision'].startswith(commit_id) or commit_id.startswith(cached_file['Revision'])): cached_file_path = os.path.join(self.cache_root_location, cached_file['Path']) if os.path.exists(cached_file_path): return cached_file_path else: self.remove_key(cached_file) return None def get_file_by_info(self, model_file_info): """Check if exist cache file. Args: model_file_info (ModelFileInfo): The file information of the file. Returns: str: The file path. """ cache_key = self.__get_cache_key(model_file_info) for cached_file in self.cached_files: if cached_file == cache_key: orig_path = os.path.join(self.cache_root_location, cached_file['Path']) if os.path.exists(orig_path): return orig_path else: self.remove_key(cached_file) break return None def __get_cache_key(self, model_file_info): cache_key = { 'Path': model_file_info['Path'], 'Revision': model_file_info['Revision'], # commit id } return cache_key def exists(self, model_file_info): """Check the file is cached or not. Note existence check will also cover digest check Args: model_file_info (CachedFileInfo): The cached file info Returns: bool: If exists and has the same hash, return True otherwise False """ key = self.__get_cache_key(model_file_info) is_exists = False file_path = key['Path'] cache_file_path = os.path.join(self.cache_root_location, model_file_info['Path']) for cached_key in self.cached_files: if cached_key['Path'] == file_path and ( cached_key['Revision'].startswith(key['Revision']) or key['Revision'].startswith(cached_key['Revision'])): expected_hash = model_file_info[FILE_HASH] if expected_hash is not None and os.path.exists( cache_file_path): # compute hash only when enabled, otherwise just meet expectation by default if enable_default_hash_validation: cache_file_sha256 = compute_hash(cache_file_path) else: cache_file_sha256 = expected_hash if expected_hash == cache_file_sha256: is_exists = True break else: logger.info( f'File [{file_path}] exists in cache but with a mismatched hash, will re-download.' ) if is_exists: if os.path.exists(cache_file_path): return True else: self.remove_key( model_file_info) # someone may manual delete the file return False def remove_if_exists(self, model_file_info): """We in cache, remove it. Args: model_file_info (ModelFileInfo): The model file information from server. """ for cached_file in self.cached_files: if cached_file['Path'] == model_file_info['Path']: self.remove_key(cached_file) file_path = os.path.join(self.cache_root_location, cached_file['Path']) if os.path.exists(file_path): os.remove(file_path) break def put_file(self, model_file_info, model_file_location): """Put model on model_file_location to cache, the model first download to /tmp, and move to cache. Args: model_file_info (str): The file description returned by get_model_files. model_file_location (str): The location of the temporary file. Returns: str: The location of the cached file. """ self.remove_if_exists(model_file_info) # backup old revision cache_key = self.__get_cache_key(model_file_info) cache_full_path = os.path.join( self.cache_root_location, cache_key['Path']) # Branch and Tag do not have same name. cache_file_dir = os.path.dirname(cache_full_path) if not os.path.exists(cache_file_dir): os.makedirs(cache_file_dir, exist_ok=True) # We can't make operation transaction move(model_file_location, cache_full_path) self.cached_files.append(cache_key) self.save_cached_files() return cache_full_path