# Copyright (c) Alibaba, Inc. and its affiliates. import os from typing import Dict, Optional, Union from urllib.parse import urlparse from modelscope.hub.api import HubApi, ModelScopeConfig from modelscope.hub.constants import FILE_HASH from modelscope.hub.git import GitCommandWrapper from modelscope.hub.utils.caching import ModelFileSystemCache from modelscope.hub.utils.utils import compute_hash from modelscope.utils.logger import get_logger logger = get_logger() def get_model_id_from_cache(model_root_path: str, ) -> str: model_cache = None # download with git if os.path.exists(os.path.join(model_root_path, '.git')): git_cmd_wrapper = GitCommandWrapper() git_url = git_cmd_wrapper.get_repo_remote_url(model_root_path) if git_url.endswith('.git'): git_url = git_url[:-4] u_parse = urlparse(git_url) model_id = u_parse.path[1:] else: # snapshot_download model_cache = ModelFileSystemCache(model_root_path) model_id = model_cache.get_model_id() return model_id def check_local_model_is_latest( model_root_path: str, user_agent: Optional[Union[Dict, str]] = None, ): """Check local model repo is latest. Check local model repo is same as hub latest version. """ try: model_id = get_model_id_from_cache(model_root_path) model_id = model_id.replace('___', '.') # make headers headers = { 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ) } cookies = ModelScopeConfig.get_cookies() snapshot_header = headers if 'CI_TEST' in os.environ else { **headers, **{ 'Snapshot': 'True' } } _api = HubApi(timeout=20) try: _, revisions = _api.get_model_branches_and_tags( model_id=model_id, use_cookies=cookies) if len(revisions) > 0: latest_revision = revisions[0] else: latest_revision = 'master' except: # noqa: E722 latest_revision = 'master' model_files = _api.get_model_files( model_id=model_id, revision=latest_revision, recursive=True, headers=snapshot_header, use_cookies=cookies, ) model_cache = None # download via non-git method if not os.path.exists(os.path.join(model_root_path, '.git')): model_cache = ModelFileSystemCache(model_root_path) for model_file in model_files: if model_file['Type'] == 'tree': continue # check model_file updated if model_cache is not None: if model_cache.exists(model_file): continue else: logger.info( f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`,' f'This is because you are using an older version or the file is updated manually.' ) break else: if FILE_HASH in model_file: local_file_hash = compute_hash( os.path.join(model_root_path, model_file['Path'])) if local_file_hash == model_file[FILE_HASH]: continue else: logger.info( f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`,' f'This is because you are using an older version or the file is updated manually.' ) break except: # noqa: E722 pass # ignore def check_model_is_id(model_id: str, token: Optional[str] = None): if model_id is None or os.path.exists(model_id): return False else: _api = HubApi() _api.login(token) try: _api.get_model(model_id=model_id, ) return True except Exception: return False