push_to_hub.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import concurrent.futures
  3. import os
  4. import shutil
  5. import tempfile
  6. from multiprocessing import Manager, Process, Value
  7. from pathlib import Path
  8. from typing import List, Optional, Union
  9. import json
  10. from modelscope.hub.api import HubApi
  11. from modelscope.hub.constants import ModelVisibility
  12. from modelscope.utils.constant import DEFAULT_REPOSITORY_REVISION
  13. from modelscope.utils.logger import get_logger
  14. logger = get_logger()
  15. _executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
  16. _queues = dict()
  17. _flags = dict()
  18. _tasks = dict()
  19. _manager = None
  20. def _push_files_to_hub(
  21. path_or_fileobj: Union[str, Path],
  22. path_in_repo: str,
  23. repo_id: str,
  24. token: Union[str, bool, None] = None,
  25. revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
  26. commit_message: Optional[str] = None,
  27. commit_description: Optional[str] = None,
  28. ):
  29. """Push files to model hub incrementally
  30. This function if used for patch_hub, user is not recommended to call this.
  31. This function will be merged to push_to_hub in later sprints.
  32. """
  33. if not os.path.exists(path_or_fileobj):
  34. return
  35. from modelscope import HubApi
  36. api = HubApi()
  37. api.login(token)
  38. if not commit_message:
  39. commit_message = 'Updating files'
  40. if commit_description:
  41. commit_message = commit_message + '\n' + commit_description
  42. with tempfile.TemporaryDirectory() as temp_cache_dir:
  43. from modelscope.hub.repository import Repository
  44. repo = Repository(temp_cache_dir, repo_id, revision=revision)
  45. if path_in_repo:
  46. sub_folder = os.path.join(temp_cache_dir, path_in_repo)
  47. else:
  48. sub_folder = temp_cache_dir
  49. os.makedirs(sub_folder, exist_ok=True)
  50. if os.path.isfile(path_or_fileobj):
  51. dest_file = os.path.join(sub_folder,
  52. os.path.basename(path_or_fileobj))
  53. shutil.copyfile(path_or_fileobj, dest_file)
  54. else:
  55. shutil.copytree(path_or_fileobj, sub_folder, dirs_exist_ok=True)
  56. repo.push(commit_message)
  57. def _api_push_to_hub(repo_name,
  58. output_dir,
  59. token,
  60. private=True,
  61. commit_message='',
  62. tag=None,
  63. source_repo='',
  64. ignore_file_pattern=None,
  65. revision=DEFAULT_REPOSITORY_REVISION):
  66. try:
  67. api = HubApi()
  68. api.login(token)
  69. api.push_model(
  70. repo_name,
  71. output_dir,
  72. visibility=ModelVisibility.PUBLIC
  73. if not private else ModelVisibility.PRIVATE,
  74. chinese_name=repo_name,
  75. commit_message=commit_message,
  76. tag=tag,
  77. original_model_id=source_repo,
  78. ignore_file_pattern=ignore_file_pattern,
  79. revision=revision)
  80. commit_message = commit_message or 'No commit message'
  81. logger.info(
  82. f'Successfully upload the model to {repo_name} with message: {commit_message}'
  83. )
  84. return True
  85. except Exception as e:
  86. logger.error(
  87. f'Error happens when uploading model {repo_name} with message: {commit_message}: {e}'
  88. )
  89. return False
  90. def push_to_hub(repo_name,
  91. output_dir,
  92. token=None,
  93. private=True,
  94. retry=3,
  95. commit_message='',
  96. tag=None,
  97. source_repo='',
  98. ignore_file_pattern=None,
  99. revision=DEFAULT_REPOSITORY_REVISION):
  100. """
  101. Args:
  102. repo_name: The repo name for the modelhub repo
  103. output_dir: The local output_dir for the checkpoint
  104. token: The user api token, function will check the `MODELSCOPE_API_TOKEN` variable if this argument is None
  105. private: If is a private repo, default True
  106. retry: Retry times if something error in uploading, default 3
  107. commit_message: The commit message
  108. tag: The tag of this commit
  109. source_repo: The source repo (model id) which this model comes from
  110. ignore_file_pattern: The file pattern to be ignored in uploading.
  111. revision: The branch to commit to
  112. Returns:
  113. The boolean value to represent whether the model is uploaded.
  114. """
  115. if token is None:
  116. token = os.environ.get('MODELSCOPE_API_TOKEN')
  117. if ignore_file_pattern is None:
  118. ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')
  119. assert repo_name is not None
  120. assert token is not None, 'Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.'
  121. assert os.path.isdir(output_dir)
  122. assert 'configuration.json' in os.listdir(output_dir) or 'configuration.yaml' in os.listdir(output_dir) \
  123. or 'configuration.yml' in os.listdir(output_dir)
  124. logger.info(
  125. f'Uploading {output_dir} to {repo_name} with message {commit_message}')
  126. for i in range(retry):
  127. if _api_push_to_hub(repo_name, output_dir, token, private,
  128. commit_message, tag, source_repo,
  129. ignore_file_pattern, revision):
  130. return True
  131. return False
  132. def push_to_hub_async(repo_name,
  133. output_dir,
  134. token=None,
  135. private=True,
  136. commit_message='',
  137. tag=None,
  138. source_repo='',
  139. ignore_file_pattern=None,
  140. revision=DEFAULT_REPOSITORY_REVISION):
  141. """
  142. Args:
  143. repo_name: The repo name for the modelhub repo
  144. output_dir: The local output_dir for the checkpoint
  145. token: The user api token, function will check the `MODELSCOPE_API_TOKEN` variable if this argument is None
  146. private: If is a private repo, default True
  147. commit_message: The commit message
  148. tag: The tag of this commit
  149. source_repo: The source repo (model id) which this model comes from
  150. ignore_file_pattern: The file pattern to be ignored in uploading
  151. revision: The branch to commit to
  152. Returns:
  153. A handler to check the result and the status
  154. """
  155. if token is None:
  156. token = os.environ.get('MODELSCOPE_API_TOKEN')
  157. if ignore_file_pattern is None:
  158. ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')
  159. assert repo_name is not None
  160. assert token is not None, 'Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.'
  161. assert os.path.isdir(output_dir)
  162. assert 'configuration.json' in os.listdir(output_dir) or 'configuration.yaml' in os.listdir(output_dir) \
  163. or 'configuration.yml' in os.listdir(output_dir)
  164. logger.info(
  165. f'Uploading {output_dir} to {repo_name} with message {commit_message}')
  166. return _executor.submit(_api_push_to_hub, repo_name, output_dir, token,
  167. private, commit_message, tag, source_repo,
  168. ignore_file_pattern, revision)
  169. def submit_task(q, b):
  170. while True:
  171. b.value = False
  172. item = q.get()
  173. logger.info(item)
  174. b.value = True
  175. if not item.pop('done', False):
  176. delete_dir = item.pop('delete_dir', False)
  177. output_dir = item.get('output_dir')
  178. try:
  179. push_to_hub(**item)
  180. if delete_dir and os.path.exists(output_dir):
  181. shutil.rmtree(output_dir)
  182. except Exception as e:
  183. logger.error(e)
  184. else:
  185. break
  186. class UploadStrategy:
  187. cancel = 'cancel'
  188. wait = 'wait'
  189. def push_to_hub_in_queue(queue_name, strategy=UploadStrategy.cancel, **kwargs):
  190. assert queue_name is not None and len(
  191. queue_name) > 0, 'Please specify a valid queue name!'
  192. global _manager
  193. if _manager is None:
  194. _manager = Manager()
  195. if queue_name not in _queues:
  196. _queues[queue_name] = _manager.Queue()
  197. _flags[queue_name] = Value('b', False)
  198. process = Process(
  199. target=submit_task, args=(_queues[queue_name], _flags[queue_name]))
  200. process.start()
  201. _tasks[queue_name] = process
  202. queue = _queues[queue_name]
  203. flag: Value = _flags[queue_name]
  204. if kwargs.get('done', False):
  205. queue.put(kwargs)
  206. elif flag.value and strategy == UploadStrategy.cancel:
  207. logger.error(
  208. f'Another uploading is running, '
  209. f'this uploading with message {kwargs.get("commit_message")} will be canceled.'
  210. )
  211. else:
  212. queue.put(kwargs)
  213. def wait_for_done(queue_name):
  214. process: Process = _tasks.pop(queue_name, None)
  215. if process is None:
  216. return
  217. process.join()
  218. _queues.pop(queue_name)
  219. _flags.pop(queue_name)