check_model.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Dict, Optional, Union
  4. from urllib.parse import urlparse
  5. from modelscope.hub.api import HubApi, ModelScopeConfig
  6. from modelscope.hub.constants import FILE_HASH
  7. from modelscope.hub.git import GitCommandWrapper
  8. from modelscope.hub.utils.caching import ModelFileSystemCache
  9. from modelscope.hub.utils.utils import compute_hash
  10. from modelscope.utils.logger import get_logger
# Module-level logger used by the functions in this file.
logger = get_logger()
  12. def get_model_id_from_cache(model_root_path: str, ) -> str:
  13. model_cache = None
  14. # download with git
  15. if os.path.exists(os.path.join(model_root_path, '.git')):
  16. git_cmd_wrapper = GitCommandWrapper()
  17. git_url = git_cmd_wrapper.get_repo_remote_url(model_root_path)
  18. if git_url.endswith('.git'):
  19. git_url = git_url[:-4]
  20. u_parse = urlparse(git_url)
  21. model_id = u_parse.path[1:]
  22. else: # snapshot_download
  23. model_cache = ModelFileSystemCache(model_root_path)
  24. model_id = model_cache.get_model_id()
  25. return model_id
  26. def check_local_model_is_latest(
  27. model_root_path: str,
  28. user_agent: Optional[Union[Dict, str]] = None,
  29. token: Optional[str] = None,
  30. ):
  31. """Check local model repo is latest.
  32. Check local model repo is same as hub latest version.
  33. """
  34. try:
  35. model_id = get_model_id_from_cache(model_root_path)
  36. model_id = model_id.replace('___', '.')
  37. # make headers
  38. headers = {
  39. 'user-agent':
  40. ModelScopeConfig.get_user_agent(user_agent=user_agent, )
  41. }
  42. _api = HubApi(timeout=20, token=token)
  43. cookies = _api.get_cookies()
  44. snapshot_header = headers if 'CI_TEST' in os.environ else {
  45. **headers,
  46. **{
  47. 'Snapshot': 'True'
  48. }
  49. }
  50. try:
  51. _, revisions = _api.get_model_branches_and_tags(
  52. model_id=model_id, use_cookies=cookies)
  53. if len(revisions) > 0:
  54. latest_revision = revisions[0]
  55. else:
  56. latest_revision = 'master'
  57. except: # noqa: E722
  58. latest_revision = 'master'
  59. model_files = _api.get_model_files(
  60. model_id=model_id,
  61. revision=latest_revision,
  62. recursive=True,
  63. headers=snapshot_header,
  64. use_cookies=cookies,
  65. )
  66. model_cache = None
  67. # download via non-git method
  68. if not os.path.exists(os.path.join(model_root_path, '.git')):
  69. model_cache = ModelFileSystemCache(model_root_path)
  70. for model_file in model_files:
  71. if model_file['Type'] == 'tree':
  72. continue
  73. # check model_file updated
  74. if model_cache is not None:
  75. if model_cache.exists(model_file):
  76. continue
  77. else:
  78. logger.info(
  79. f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`,'
  80. f'This is because you are using an older version or the file is updated manually.'
  81. )
  82. break
  83. else:
  84. if FILE_HASH in model_file:
  85. local_file_hash = compute_hash(
  86. os.path.join(model_root_path, model_file['Path']))
  87. if local_file_hash == model_file[FILE_HASH]:
  88. continue
  89. else:
  90. logger.info(
  91. f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`,'
  92. f'This is because you are using an older version or the file is updated manually.'
  93. )
  94. break
  95. except: # noqa: E722
  96. pass # ignore
  97. def check_model_is_id(model_id: str, token: Optional[str] = None):
  98. if model_id is None or os.path.exists(model_id):
  99. return False
  100. else:
  101. _api = HubApi()
  102. _api.login(token)
  103. try:
  104. _api.get_model(model_id=model_id, )
  105. return True
  106. except Exception:
  107. return False