| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import hashlib
- import os
- import pickle
- import tempfile
- import threading
- from shutil import move, rmtree
- from typing import Dict
- from modelscope.hub.constants import ( # noqa
- FILE_HASH, MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION)
- from modelscope.hub.utils.utils import compute_hash
- from modelscope.utils.logger import get_logger
logger = get_logger()

# Opt-in integrity checking of cached files, read once at import time from the
# environment variable named by MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION.
# Only the value 'true' (case-insensitive, surrounding whitespace ignored)
# enables it; when enabled, ModelFileSystemCache.exists() re-hashes the cached
# file instead of trusting the expected hash.
enable_default_hash_validation = (
    os.getenv(MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION, 'False')
    .strip()
    .lower() == 'true')

# Implements caching functionality, used internally only.
class FileSystemCache(object):
    """Local file cache.

    Keeps an in-memory index (``self.cached_files``) of cached entries and
    persists it with pickle to ``KEY_FILE_NAME`` inside
    ``cache_root_location``.
    """

    KEY_FILE_NAME = '.msc'
    MODEL_META_FILE_NAME = '.mdl'
    MODEL_META_MODEL_ID = 'id'
    MODEL_VERSION_FILE_NAME = '.mv'

    def __init__(
        self,
        cache_root_location: str,
        **kwargs,
    ):
        """Base file system cache interface.

        Args:
            cache_root_location (str): The root location to store files. It is
                created (including parents) if it does not exist.
            kwargs(dict): The keyword arguments (not used by this class).
        """
        # Re-entrant lock guarding writes of the index file.
        self._cache_lock = threading.RLock()
        os.makedirs(cache_root_location, exist_ok=True)
        self.cache_root_location = cache_root_location
        self.load_cache()

    def get_root_location(self):
        """Return the root directory of this cache."""
        return self.cache_root_location

    def load_cache(self):
        """(Re)load the cache index from disk.

        A missing index file yields an empty index.
        """
        self.cached_files = []
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        if os.path.exists(cache_keys_file_path):
            # NOTE(review): pickle.load can execute arbitrary code; this assumes
            # the index file is only ever written by save_cached_files().
            with open(cache_keys_file_path, 'rb') as f:
                self.cached_files = pickle.load(f)

    def save_cached_files(self):
        """Persist the cache index.

        The metadata is used to verify that cached content is consistent with
        the remote content. Example of the cached content:
        [{'Path': 'configuration.json', 'Revision': 'f01dxxx'}, {'Path': 'model.bin', 'Revision': '1159xxx'}, ...]

        The index is written to a temporary file in the cache root and then
        swapped in with ``os.replace``, so readers never see a partially
        written index. On failure the temporary file is removed and the
        exception is re-raised.
        """
        with self._cache_lock:
            cache_keys_file_path = os.path.join(self.cache_root_location,
                                                FileSystemCache.KEY_FILE_NAME)
            # Same directory as the target so the final replace is atomic.
            fd, temp_filename = tempfile.mkstemp(
                suffix='.tmp', dir=self.cache_root_location)
            fd_owned = True  # True while we, not a file object, must close fd
            try:
                with os.fdopen(fd, 'wb') as f:
                    # The file object now owns fd and closes it on exit;
                    # closing fd again later could close an unrelated file
                    # that has reused the same descriptor number.
                    fd_owned = False
                    pickle.dump(self.cached_files, f)
                # os.replace overwrites an existing index atomically on both
                # POSIX and Windows (shutil.move may fall back to copy+unlink).
                os.replace(temp_filename, cache_keys_file_path)
            except BaseException:
                if fd_owned:
                    os.close(fd)
                if os.path.exists(temp_filename):
                    os.unlink(temp_filename)
                raise

    def get_file(self, key):
        """Check the key is in the cache, if exist, return the file, otherwise return None.

        The base implementation does nothing and returns None; subclasses
        provide the actual lookup.

        Args:
            key(str): The cache key.
        Raises:
            None
        """
        pass

    def put_file(self, key, location):
        """Put file to the cache.

        The base implementation does nothing and returns None; subclasses
        provide the actual storage.

        Args:
            key (str): The cache key
            location (str): Location of the file, we will move the file to cache.
        Raises:
            None
        """
        pass

    def remove_key(self, key):
        """Remove a key from the index and persist the index.

        The cached file itself is not deleted; callers remove it manually.

        Args:
            key (dict): The cache key.
        """
        if key in self.cached_files:
            self.cached_files.remove(key)
            self.save_cached_files()

    def exists(self, key):
        """Return True if ``key`` is present in the index."""
        return any(cache_file == key for cache_file in self.cached_files)

    def clear_cache(self):
        """Remove all files and metadata from the cache.

        The root directory is recreated empty afterwards, so the cache stays
        usable. Calling this repeatedly is safe.
        """
        if os.path.isdir(self.cache_root_location):
            rmtree(self.cache_root_location)
        # Recreate the root: save_cached_files() creates its temp file there.
        os.makedirs(self.cache_root_location, exist_ok=True)
        self.load_cache()

    def hash_name(self, key):
        """Return the SHA-256 hex digest of the UTF-8 encoding of ``key``."""
        return hashlib.sha256(key.encode()).hexdigest()
class ModelFileSystemCache(FileSystemCache):
    """Local cache file layout

    cache_root/owner/model_name/individual cached files and cache index file '.msc'
    Save only one version for each file.
    """

    def __init__(self, cache_root, owner=None, name=None):
        """Open the cache directory of one model, creating it if needed.

        Args:
            cache_root(`str`): The modelscope local cache root
                (e.g. ~/.cache/modelscope/hub).
            owner(`str`): The model owner.
            name(`str`): The name of the model.

        If ``owner`` or ``name`` is None, ``cache_root`` itself is used as the
        model directory and the model meta is loaded from its meta file (model
        id 'unknown' when there is none). Otherwise the model directory is
        ``cache_root/owner/name`` and the meta file is (re)written with
        model_id = {owner}/{name}.
        """
        if owner is None or name is None:
            # os.path.join with a single argument converts a PathLike to str.
            super().__init__(os.path.join(cache_root))
            self.load_model_meta()
        else:
            super().__init__(os.path.join(cache_root, owner, name))
            self.model_meta = {
                FileSystemCache.MODEL_META_MODEL_ID: '%s/%s' % (owner, name)
            }
            self.save_model_meta()
        self.cached_model_revision = self.load_model_version()

    def load_model_meta(self):
        """Load ``self.model_meta`` from the meta file, or default its id to 'unknown'."""
        meta_file_path = os.path.join(self.cache_root_location,
                                      FileSystemCache.MODEL_META_FILE_NAME)
        if os.path.exists(meta_file_path):
            with open(meta_file_path, 'rb') as f:
                self.model_meta = pickle.load(f)
        else:
            self.model_meta = {FileSystemCache.MODEL_META_MODEL_ID: 'unknown'}

    def load_model_version(self):
        """Return the stripped content of the version file, or None if it is absent."""
        model_version_file_path = os.path.join(
            self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
        if os.path.exists(model_version_file_path):
            with open(model_version_file_path, 'r') as f:
                return f.read().strip()
        else:
            return None

    def save_model_version(self, revision_info: Dict):
        """Write the cached model version file.

        Args:
            revision_info: Either a dict with 'Revision' and 'CreatedAt' keys,
                stored as 'Revision:<r>,CreatedAt:<t>', or a str written as-is.
        """
        model_version_file_path = os.path.join(
            self.cache_root_location, FileSystemCache.MODEL_VERSION_FILE_NAME)
        with open(model_version_file_path, 'w') as f:
            if isinstance(revision_info, dict):
                version_info_str = 'Revision:%s,CreatedAt:%s' % (
                    revision_info['Revision'], revision_info['CreatedAt'])
                f.write(version_info_str)
            else:
                f.write(revision_info)

    def get_model_id(self):
        """Return the model id stored in the model meta."""
        return self.model_meta[FileSystemCache.MODEL_META_MODEL_ID]

    def save_model_meta(self):
        """Pickle ``self.model_meta`` to the meta file in the model directory."""
        meta_file_path = os.path.join(self.cache_root_location,
                                      FileSystemCache.MODEL_META_FILE_NAME)
        with open(meta_file_path, 'wb') as f:
            pickle.dump(self.model_meta, f)

    @staticmethod
    def _revision_matches(cached_revision, revision):
        """Return True when either revision id is a prefix of the other.

        This lets an abbreviated commit id match its full form in both
        directions.
        """
        return (cached_revision.startswith(revision)
                or revision.startswith(cached_revision))

    def get_file_by_path(self, file_path):
        """Retrieve the cached file whose path matches ``file_path``.

        Index entries for that path whose file no longer exists on disk are
        removed from the index.

        Args:
            file_path (str): The file path in the model.

        Returns:
            str or None: The full path of the cached file, or None.
        """
        # Iterate over a snapshot: remove_key() mutates self.cached_files,
        # which would otherwise make the loop skip the following entry.
        for cached_file in list(self.cached_files):
            if cached_file['Path'] != file_path:
                continue
            cached_file_path = os.path.join(self.cache_root_location,
                                            cached_file['Path'])
            if os.path.exists(cached_file_path):
                return cached_file_path
            self.remove_key(cached_file)
        return None

    def get_file_by_path_and_commit_id(self, file_path, commit_id):
        """Retrieve the cached file matching both path and commit id.

        The commit id matches when it is a prefix of the cached revision or
        vice versa. Matching index entries whose file no longer exists on disk
        are removed from the index.

        Args:
            file_path (str): The file path in the model.
            commit_id (str): The commit id of the file.

        Returns:
            str or None: The full path of the cached file, or None.
        """
        # Snapshot for the same reason as in get_file_by_path().
        for cached_file in list(self.cached_files):
            if cached_file['Path'] != file_path or not self._revision_matches(
                    cached_file['Revision'], commit_id):
                continue
            cached_file_path = os.path.join(self.cache_root_location,
                                            cached_file['Path'])
            if os.path.exists(cached_file_path):
                return cached_file_path
            self.remove_key(cached_file)
        return None

    def get_file_by_info(self, model_file_info):
        """Return the cached file whose key exactly equals that of ``model_file_info``.

        If the indexed file is missing on disk, its entry is removed.

        Args:
            model_file_info (ModelFileInfo): The file information of the file.

        Returns:
            str or None: The file path, or None when not cached.
        """
        cache_key = self.__get_cache_key(model_file_info)
        for cached_file in self.cached_files:
            if cached_file == cache_key:
                orig_path = os.path.join(self.cache_root_location,
                                         cached_file['Path'])
                if os.path.exists(orig_path):
                    return orig_path
                # Safe to mutate here: we stop iterating right away.
                self.remove_key(cached_file)
                break
        return None

    def __get_cache_key(self, model_file_info):
        """Build the index key ({'Path', 'Revision'}) from a file-info mapping."""
        cache_key = {
            'Path': model_file_info['Path'],
            'Revision': model_file_info['Revision'],  # commit id
        }
        return cache_key

    def exists(self, model_file_info):
        """Check the file is cached or not. Note existence check will also cover digest check.

        An index entry matches when its Path equals the file's Path and its
        Revision is a prefix of the file's Revision or vice versa. Matching
        entries whose file was deleted from disk are dropped from the index.

        Args:
            model_file_info (CachedFileInfo): The cached file info

        Returns:
            bool: True only if a matching entry exists, the file is on disk, the
            expected hash (FILE_HASH) is not None and the hash check passes.
            The hash is only recomputed when default hash validation is
            enabled; otherwise the expected hash is trusted.
        """
        key = self.__get_cache_key(model_file_info)
        file_path = key['Path']
        cache_file_path = os.path.join(self.cache_root_location, file_path)
        actual_hash = None  # computed at most once, on first need
        # Snapshot: stale entries may be removed while scanning.
        for cached_key in list(self.cached_files):
            if cached_key['Path'] != file_path or not self._revision_matches(
                    cached_key['Revision'], key['Revision']):
                continue
            if not os.path.exists(cache_file_path):
                # someone may manually delete the file: drop the stale entry
                # (the index stores cache keys, not the full file info).
                self.remove_key(cached_key)
                continue
            expected_hash = model_file_info[FILE_HASH]
            if expected_hash is None:
                continue
            if actual_hash is None:
                # compute hash only when enabled, otherwise just meet expectation by default
                if enable_default_hash_validation:
                    actual_hash = compute_hash(cache_file_path)
                else:
                    actual_hash = expected_hash
            if expected_hash == actual_hash:
                return True
            logger.info(
                f'File [{file_path}] exists in cache but with a mismatched hash, will re-download.'
            )
        return False

    def remove_if_exists(self, model_file_info):
        """Remove the cached entry (and its file) for this path, if cached.

        Only the first index entry with the same Path is removed.

        Args:
            model_file_info (ModelFileInfo): The model file information from server.
        """
        for cached_file in self.cached_files:
            if cached_file['Path'] == model_file_info['Path']:
                self.remove_key(cached_file)
                file_path = os.path.join(self.cache_root_location,
                                         cached_file['Path'])
                if os.path.exists(file_path):
                    os.remove(file_path)
                # Stop immediately: remove_key() just mutated the list.
                break

    def put_file(self, model_file_info, model_file_location):
        """Move a downloaded file into the cache and index it.

        Any previously cached revision of the same path is removed first, so
        only one version per file is kept.

        Args:
            model_file_info (dict): The file description returned by
                get_model_files; must provide 'Path' and 'Revision'.
            model_file_location (str): The location of the temporary file.

        Returns:
            str: The location of the cached file.
        """
        self.remove_if_exists(model_file_info)
        cache_key = self.__get_cache_key(model_file_info)
        cache_full_path = os.path.join(self.cache_root_location,
                                       cache_key['Path'])
        os.makedirs(os.path.dirname(cache_full_path), exist_ok=True)
        # Not transactional: if the process dies between the move and the
        # index save, the file is present on disk but not indexed.
        move(model_file_location, cache_full_path)
        self.cached_files.append(cache_key)
        self.save_cached_files()
        return cache_full_path
|