| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from typing import Dict, Optional, Union
- from urllib.parse import urlparse
- from modelscope.hub.api import HubApi, ModelScopeConfig
- from modelscope.hub.constants import FILE_HASH
- from modelscope.hub.git import GitCommandWrapper
- from modelscope.hub.utils.caching import ModelFileSystemCache
- from modelscope.hub.utils.utils import compute_hash
- from modelscope.utils.logger import get_logger
# Module-level logger shared by the helpers in this file.
logger = get_logger()
def get_model_id_from_cache(model_root_path: str) -> str:
    """Recover the hub model id of a locally downloaded model.

    Two download layouts are recognised:

    * git clone (``<model_root_path>/.git`` exists): the id is the path
      component of the repository's remote URL, with any trailing slash,
      the ``.git`` suffix and the surrounding slashes removed, e.g.
      ``https://host/namespace/name.git`` -> ``namespace/name``;
    * snapshot download (no ``.git`` directory): the id is read from the
      ``ModelFileSystemCache`` rooted at ``model_root_path``.

    Args:
        model_root_path: Directory the model was downloaded to.

    Returns:
        The model id recovered from the git remote or the download cache.
    """
    if os.path.exists(os.path.join(model_root_path, '.git')):
        git_url = GitCommandWrapper().get_repo_remote_url(model_root_path)
        # Drop a trailing slash first so a remote like ``.../name.git/`` still
        # has its ``.git`` suffix removed.
        git_url = git_url.rstrip('/')
        if git_url.endswith('.git'):
            git_url = git_url[:-len('.git')]
        # strip('/') instead of [1:] so no stray slash leaks into the id.
        return urlparse(git_url).path.strip('/')

    # Downloaded via snapshot_download: the cache records the model id.
    return ModelFileSystemCache(model_root_path).get_model_id()
def _log_file_outdated(file_name: str, revision: str) -> None:
    """Log that local file ``file_name`` differs from the hub copy at ``revision``."""
    logger.info(
        f'Model file {file_name} is different from the latest version `{revision}`. '
        f'This is because you are using an older version or the file is updated manually.'
    )


def _get_latest_revision(api, model_id: str, cookies) -> str:
    """Return the revision to compare against, falling back to ``'master'``.

    Uses the first entry of the revision list returned by
    ``api.get_model_branches_and_tags`` (presumably the newest one; confirm
    against HubApi). Any failure, or an empty list, yields ``'master'``.
    """
    try:
        _, revisions = api.get_model_branches_and_tags(
            model_id=model_id, use_cookies=cookies)
        return revisions[0] if revisions else 'master'
    except Exception:  # best-effort lookup: fall back to the default branch
        return 'master'


def _local_file_is_latest(model_root_path: str, model_file: dict,
                          model_cache) -> bool:
    """Return True if the local copy of hub entry ``model_file`` is up to date.

    Args:
        model_root_path: Local model directory.
        model_file: One file entry from ``HubApi.get_model_files``.
        model_cache: ``ModelFileSystemCache`` for snapshot downloads, or None
            for git clones.
    """
    if model_cache is not None:
        return model_cache.exists(model_file)
    if FILE_HASH not in model_file:
        # The hub published no hash for this file, so there is nothing to
        # compare against; treat it as up to date.
        return True
    local_path = os.path.join(model_root_path, model_file['Path'])
    if not os.path.isfile(local_path):
        # A file present on the hub but missing locally means the local repo
        # is outdated. Previously compute_hash raised here and the error was
        # swallowed without reporting anything.
        return False
    return compute_hash(local_path) == model_file[FILE_HASH]


def check_local_model_is_latest(
    model_root_path: str,
    user_agent: Optional[Union[Dict, str]] = None,
):
    """Check whether the local model repo matches the latest hub version.

    Every file listed on the hub for the latest revision is compared with
    the local copy under ``model_root_path``:

    * snapshot downloads (no ``.git`` directory) use
      ``ModelFileSystemCache.exists``;
    * git clones compare ``compute_hash`` of the local file with the hub's
      ``FILE_HASH``. A file missing locally counts as outdated.

    An info message is logged for the first file that differs. The function
    is best-effort: any error is logged at debug level and swallowed, and
    nothing is returned.

    Args:
        model_root_path: Local directory of the downloaded model.
        user_agent: Forwarded to ``ModelScopeConfig.get_user_agent`` to build
            the ``user-agent`` request header.
    """
    try:
        model_id = get_model_id_from_cache(model_root_path)
        model_id = model_id.replace('___', '.')

        headers = {
            'user-agent':
            ModelScopeConfig.get_user_agent(user_agent=user_agent)
        }
        cookies = ModelScopeConfig.get_cookies()
        # The 'Snapshot' header is omitted when CI_TEST is set (presumably so
        # CI runs are not treated as snapshot downloads).
        snapshot_header = headers if 'CI_TEST' in os.environ else {
            **headers, 'Snapshot': 'True'
        }

        _api = HubApi(timeout=20)
        latest_revision = _get_latest_revision(_api, model_id, cookies)
        model_files = _api.get_model_files(
            model_id=model_id,
            revision=latest_revision,
            recursive=True,
            headers=snapshot_header,
            use_cookies=cookies,
        )

        # Downloaded via a non-git method: compare against the download cache.
        model_cache = None
        if not os.path.exists(os.path.join(model_root_path, '.git')):
            model_cache = ModelFileSystemCache(model_root_path)

        for model_file in model_files:
            if model_file['Type'] == 'tree':  # directories carry no content
                continue
            if not _local_file_is_latest(model_root_path, model_file,
                                         model_cache):
                _log_file_outdated(model_file['Name'], latest_revision)
                break
    except Exception as err:
        # Best-effort check: never fail the caller, but leave a trace.
        logger.debug(
            'Skip checking whether local model at %s is the latest: %s',
            model_root_path, err)
def check_model_is_id(model_id: str, token: Optional[str] = None):
    """Tell whether ``model_id`` refers to a model on the hub.

    Args:
        model_id: Candidate hub model id, or a local filesystem path.
        token: Access token handed to ``HubApi.login`` before the lookup.

    Returns:
        False if ``model_id`` is None, names an existing local path, or the
        ``HubApi.get_model`` lookup raises. True if the lookup succeeds.
    """
    # None or an existing path can never be a hub id, so skip the network.
    if model_id is None:
        return False
    if os.path.exists(model_id):
        return False

    hub = HubApi()
    # NOTE(review): login runs even when token is None and sits outside the
    # try below, so any error it raises propagates to the caller. Confirm
    # this is intended.
    hub.login(token)
    try:
        hub.get_model(model_id=model_id)
    except Exception:
        return False
    return True
|