check_model.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Dict, Optional, Union
  4. from urllib.parse import urlparse
  5. from modelscope.hub.api import HubApi, ModelScopeConfig
  6. from modelscope.hub.constants import FILE_HASH
  7. from modelscope.hub.git import GitCommandWrapper
  8. from modelscope.hub.utils.caching import ModelFileSystemCache
  9. from modelscope.hub.utils.utils import compute_hash
  10. from modelscope.utils.logger import get_logger
  11. logger = get_logger()
  def get_model_id_from_cache(model_root_path: str, ) -> str:
      """Resolve the model id that a local model directory belongs to.

      Args:
          model_root_path: Root directory of the locally stored model.

      Returns:
          The model id. When the directory contains a ``.git`` entry, the id
          is taken from the path part of the repository's remote URL (a
          trailing ``.git`` is removed and the first character of the path
          is dropped). Otherwise it is whatever ``ModelFileSystemCache``
          reports for that directory.
      """
      git_marker = os.path.join(model_root_path, '.git')
      if not os.path.exists(git_marker):
          # No git metadata: the model was fetched via snapshot_download.
          return ModelFileSystemCache(model_root_path).get_model_id()

      # Downloaded with git: derive the id from the remote URL.
      remote_url = GitCommandWrapper().get_repo_remote_url(model_root_path)
      git_suffix = '.git'
      if remote_url.endswith(git_suffix):
          remote_url = remote_url[:-len(git_suffix)]
      return urlparse(remote_url).path[1:]
  def check_local_model_is_latest(
      model_root_path: str,
      user_agent: Optional[Union[Dict, str]] = None,
  ):
      """Check whether the local model repo matches the hub's latest version.

      The file list of the latest revision is fetched from the hub and
      compared with the local copy. For the first file found to differ, an
      info message is logged and the check stops. The check is best-effort:
      any ``Exception`` raised along the way is logged at debug level and
      suppressed, so this function never blocks model loading.

      Args:
          model_root_path: Root directory of the locally stored model.
          user_agent: Extra user-agent information forwarded to
              ``ModelScopeConfig.get_user_agent``.

      Returns:
          None. The outcome is only reported through the logger.
      """
      try:
          model_id = get_model_id_from_cache(model_root_path)
          model_id = model_id.replace('___', '.')
          # make headers
          headers = {
              'user-agent':
              ModelScopeConfig.get_user_agent(user_agent=user_agent, )
          }
          cookies = ModelScopeConfig.get_cookies()
          # The 'Snapshot' header is added unless CI_TEST is set in the
          # environment.
          if 'CI_TEST' in os.environ:
              snapshot_header = headers
          else:
              snapshot_header = {**headers, 'Snapshot': 'True'}

          _api = HubApi(timeout=20)
          try:
              _, revisions = _api.get_model_branches_and_tags(
                  model_id=model_id, use_cookies=cookies)
              latest_revision = revisions[0] if revisions else 'master'
          except Exception:
              # Could not list revisions: fall back to the default branch.
              latest_revision = 'master'

          model_files = _api.get_model_files(
              model_id=model_id,
              revision=latest_revision,
              recursive=True,
              headers=snapshot_header,
              use_cookies=cookies,
          )

          # Models downloaded without git are tracked by the file-system
          # cache; git checkouts are verified by hashing the local files.
          model_cache = None
          if not os.path.exists(os.path.join(model_root_path, '.git')):
              model_cache = ModelFileSystemCache(model_root_path)

          for model_file in model_files:
              if model_file['Type'] == 'tree':
                  continue
              # check model_file updated
              if model_cache is not None:
                  up_to_date = model_cache.exists(model_file)
              elif FILE_HASH in model_file:
                  local_file_hash = compute_hash(
                      os.path.join(model_root_path, model_file['Path']))
                  up_to_date = local_file_hash == model_file[FILE_HASH]
              else:
                  # No remote hash to compare against: nothing to verify.
                  up_to_date = True
              if not up_to_date:
                  logger.info(
                      f'Model file {model_file["Name"]} is different from the latest version `{latest_revision}`,'
                      'This is because you are using an older version or the file is updated manually.'
                  )
                  break
      except Exception as e:
          # Best-effort check: never fail the caller, but keep a trace of why
          # the check was skipped. KeyboardInterrupt/SystemExit still
          # propagate because only Exception is caught.
          logger.debug('Skip checking whether the local model is latest: %s',
                       e)
  def check_model_is_id(model_id: str, token: Optional[str] = None):
      """Tell whether ``model_id`` refers to a model on the hub.

      Args:
          model_id: Candidate model id. ``None`` and anything that exists as
              a local path are rejected without contacting the hub.
          token: Token passed to ``HubApi.login`` before the lookup.

      Returns:
          True if ``HubApi.get_model`` succeeds for ``model_id``; False if
          ``model_id`` is None, is an existing local path, or the lookup
          raises an ``Exception``.
      """
      if model_id is None:
          return False
      if os.path.exists(model_id):
          # An existing local path is treated as a directory/file, not an id.
          return False

      hub_api = HubApi()
      hub_api.login(token)
      try:
          hub_api.get_model(model_id=model_id)
      except Exception:
          return False
      return True