repository.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import warnings
  4. from typing import Optional
  5. from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException
  6. from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
  7. DEFAULT_REPOSITORY_REVISION,
  8. MASTER_MODEL_BRANCH)
  9. from modelscope.utils.logger import get_logger
  10. from .git import GitCommandWrapper
  11. from .utils.utils import get_endpoint
# Module-level logger from modelscope's logging utility, shared by both
# repository classes below.
logger = get_logger()
  13. class Repository:
  14. """A local representation of the model git repository.
  15. """
  16. def __init__(self,
  17. model_dir: str,
  18. clone_from: str,
  19. revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
  20. auth_token: Optional[str] = None,
  21. git_path: Optional[str] = None):
  22. """Instantiate a Repository object by cloning the remote ModelScopeHub repo
  23. Args:
  24. model_dir (str): The model root directory.
  25. clone_from (str): model id in ModelScope-hub from which git clone
  26. revision (str, optional): revision of the model you want to clone from.
  27. Can be any of a branch, tag or commit hash
  28. auth_token (str, optional): token obtained when calling `HubApi.login()`.
  29. Usually you can safely ignore the parameter as the token is already
  30. saved when you login the first time, if None, we will use saved token.
  31. git_path (str, optional): The git command line path, if None, we use 'git'
  32. Raises:
  33. InvalidParameter: revision is None.
  34. """
  35. self.model_dir = model_dir
  36. self.model_base_dir = os.path.dirname(model_dir)
  37. self.model_repo_name = os.path.basename(model_dir)
  38. if not revision:
  39. err_msg = 'a non-default value of revision cannot be empty.'
  40. raise InvalidParameter(err_msg)
  41. from modelscope.hub.api import ModelScopeConfig
  42. if auth_token:
  43. self.auth_token = auth_token
  44. else:
  45. self.auth_token = ModelScopeConfig.get_token()
  46. git_wrapper = GitCommandWrapper()
  47. if not git_wrapper.is_lfs_installed():
  48. logger.error('git lfs is not installed, please install.')
  49. self.git_wrapper = GitCommandWrapper(git_path)
  50. os.makedirs(self.model_dir, exist_ok=True)
  51. url = self._get_model_id_url(clone_from)
  52. if os.listdir(self.model_dir): # directory not empty.
  53. remote_url = self._get_remote_url()
  54. remote_url = self.git_wrapper.remove_token_from_url(remote_url)
  55. if remote_url and remote_url == url: # need not clone again
  56. return
  57. self.git_wrapper.clone(self.model_base_dir, self.auth_token, url,
  58. self.model_repo_name, revision)
  59. if git_wrapper.is_lfs_installed():
  60. git_wrapper.git_lfs_install(self.model_dir) # init repo lfs
  61. # add user info if login
  62. self.git_wrapper.add_user_info(self.model_base_dir,
  63. self.model_repo_name)
  64. if self.auth_token: # config remote with auth token
  65. self.git_wrapper.config_auth_token(self.model_dir, self.auth_token)
  66. def _get_model_id_url(self, model_id):
  67. url = f'{get_endpoint()}/{model_id}.git'
  68. return url
  69. def _get_remote_url(self):
  70. try:
  71. remote = self.git_wrapper.get_repo_remote_url(self.model_dir)
  72. except GitError:
  73. remote = None
  74. return remote
  75. def pull(self, remote: str = 'origin', branch: str = 'master'):
  76. """Pull remote branch
  77. Args:
  78. remote (str, optional): The remote name. Defaults to 'origin'.
  79. branch (str, optional): The remote branch. Defaults to 'master'.
  80. """
  81. self.git_wrapper.pull(self.model_dir, remote=remote, branch=branch)
  82. def add_lfs_type(self, file_name_suffix: str):
  83. """Add file suffix to lfs list.
  84. Args:
  85. file_name_suffix (str): The file name suffix.
  86. examples '*.safetensors'
  87. """
  88. os.system(
  89. "printf '\n%s filter=lfs diff=lfs merge=lfs -text\n'>>%s" %
  90. (file_name_suffix, os.path.join(self.model_dir, '.gitattributes')))
  91. def push(self,
  92. commit_message: str,
  93. local_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION,
  94. remote_branch: Optional[str] = DEFAULT_REPOSITORY_REVISION,
  95. force: Optional[bool] = False):
  96. warnings.warn(
  97. 'This function is deprecated and will be removed in future versions. '
  98. 'Please use git command directly or use HubApi().upload_folder instead',
  99. DeprecationWarning,
  100. stacklevel=2)
  101. """Push local files to remote, this method will do.
  102. Execute git pull, git add, git commit, git push in order.
  103. Args:
  104. commit_message (str): commit message
  105. local_branch(str, optional): The local branch, default master.
  106. remote_branch (str, optional): The remote branch to push, default master.
  107. force (bool, optional): whether to use forced-push.
  108. Raises:
  109. InvalidParameter: no commit message.
  110. NotLoginException: no auth token.
  111. """
  112. if commit_message is None or not isinstance(commit_message, str):
  113. msg = 'commit_message must be provided!'
  114. raise InvalidParameter(msg)
  115. if not isinstance(force, bool):
  116. raise InvalidParameter('force must be bool')
  117. if not self.auth_token:
  118. raise NotLoginException('Must login to push, please login first.')
  119. self.git_wrapper.config_auth_token(self.model_dir, self.auth_token)
  120. self.git_wrapper.add_user_info(self.model_base_dir,
  121. self.model_repo_name)
  122. url = self.git_wrapper.get_repo_remote_url(self.model_dir)
  123. self.git_wrapper.add(self.model_dir, all_files=True)
  124. self.git_wrapper.commit(self.model_dir, commit_message)
  125. self.git_wrapper.push(
  126. repo_dir=self.model_dir,
  127. token=self.auth_token,
  128. url=url,
  129. local_branch=local_branch,
  130. remote_branch=remote_branch)
  131. def tag(self,
  132. tag_name: str,
  133. message: str,
  134. ref: Optional[str] = MASTER_MODEL_BRANCH):
  135. """Create a new tag.
  136. Args:
  137. tag_name (str): The name of the tag
  138. message (str): The tag message.
  139. ref (str, optional): The tag reference, can be commit id or branch.
  140. Raises:
  141. InvalidParameter: no commit message.
  142. """
  143. if tag_name is None or tag_name == '':
  144. msg = 'We use tag-based revision, therefore tag_name cannot be None or empty.'
  145. raise InvalidParameter(msg)
  146. if message is None or message == '':
  147. msg = 'We use annotated tag, therefore message cannot None or empty.'
  148. raise InvalidParameter(msg)
  149. self.git_wrapper.tag(
  150. repo_dir=self.model_dir,
  151. tag_name=tag_name,
  152. message=message,
  153. ref=ref)
  154. def tag_and_push(self,
  155. tag_name: str,
  156. message: str,
  157. ref: Optional[str] = MASTER_MODEL_BRANCH):
  158. """Create tag and push to remote
  159. Args:
  160. tag_name (str): The name of the tag
  161. message (str): The tag message.
  162. ref (str, optional): The tag ref, can be commit id or branch. Defaults to MASTER_MODEL_BRANCH.
  163. """
  164. self.tag(tag_name, message, ref)
  165. self.git_wrapper.push_tag(repo_dir=self.model_dir, tag_name=tag_name)
  166. class DatasetRepository:
  167. """A local representation of the dataset (metadata) git repository.
  168. """
  169. def __init__(self,
  170. repo_work_dir: str,
  171. dataset_id: str,
  172. revision: Optional[str] = DEFAULT_DATASET_REVISION,
  173. auth_token: Optional[str] = None,
  174. git_path: Optional[str] = None):
  175. """
  176. Instantiate a Dataset Repository object by cloning the remote ModelScope dataset repo
  177. Args:
  178. repo_work_dir (str): The dataset repo root directory.
  179. dataset_id (str): dataset id in ModelScope from which git clone
  180. revision (str, optional): revision of the dataset you want to clone from.
  181. Can be any of a branch, tag or commit hash
  182. auth_token (str, optional): token obtained when calling `HubApi.login()`.
  183. Usually you can safely ignore the parameter as the token is
  184. already saved when you login the first time, if None, we will use saved token.
  185. git_path (str, optional): The git command line path, if None, we use 'git'
  186. Raises:
  187. InvalidParameter: parameter invalid.
  188. """
  189. self.dataset_id = dataset_id
  190. if not repo_work_dir or not isinstance(repo_work_dir, str):
  191. err_msg = 'dataset_work_dir must be provided!'
  192. raise InvalidParameter(err_msg)
  193. self.repo_work_dir = repo_work_dir.rstrip('/')
  194. if not self.repo_work_dir:
  195. err_msg = 'dataset_work_dir can not be root dir!'
  196. raise InvalidParameter(err_msg)
  197. self.repo_base_dir = os.path.dirname(self.repo_work_dir)
  198. self.repo_name = os.path.basename(self.repo_work_dir)
  199. if not revision:
  200. err_msg = 'a non-default value of revision cannot be empty.'
  201. raise InvalidParameter(err_msg)
  202. self.revision = revision
  203. from modelscope.hub.api import ModelScopeConfig
  204. if auth_token:
  205. self.auth_token = auth_token
  206. else:
  207. self.auth_token = ModelScopeConfig.get_token()
  208. self.git_wrapper = GitCommandWrapper(git_path)
  209. os.makedirs(self.repo_work_dir, exist_ok=True)
  210. self.repo_url = self._get_repo_url(dataset_id=dataset_id)
  211. def clone(self) -> str:
  212. # check local repo dir, directory not empty.
  213. if os.listdir(self.repo_work_dir):
  214. remote_url = self._get_remote_url()
  215. remote_url = self.git_wrapper.remove_token_from_url(remote_url)
  216. # no need clone again
  217. if remote_url and remote_url == self.repo_url:
  218. return ''
  219. logger.info('Cloning repo from {} '.format(self.repo_url))
  220. self.git_wrapper.clone(self.repo_base_dir, self.auth_token,
  221. self.repo_url, self.repo_name, self.revision)
  222. return self.repo_work_dir
  223. def push(self,
  224. commit_message: str,
  225. branch: Optional[str] = DEFAULT_DATASET_REVISION,
  226. force: Optional[bool] = False):
  227. warnings.warn(
  228. 'This function is deprecated and will be removed in future versions. '
  229. 'Please use git command directly or use HubApi().upload_folder instead',
  230. DeprecationWarning,
  231. stacklevel=2)
  232. """Push local files to remote, this method will do.
  233. git pull
  234. git add
  235. git commit
  236. git push
  237. Args:
  238. commit_message (str): commit message
  239. branch (str, optional): which branch to push.
  240. force (bool, optional): whether to use forced-push.
  241. Raises:
  242. InvalidParameter: no commit message.
  243. NotLoginException: no access token.
  244. """
  245. if commit_message is None or not isinstance(commit_message, str):
  246. msg = 'commit_message must be provided!'
  247. raise InvalidParameter(msg)
  248. if not isinstance(force, bool):
  249. raise InvalidParameter('force must be bool')
  250. if not self.auth_token:
  251. raise NotLoginException('Must login to push, please login first.')
  252. self.git_wrapper.config_auth_token(self.repo_work_dir, self.auth_token)
  253. self.git_wrapper.add_user_info(self.repo_base_dir, self.repo_name)
  254. remote_url = self._get_remote_url()
  255. remote_url = self.git_wrapper.remove_token_from_url(remote_url)
  256. self.git_wrapper.pull(self.repo_work_dir)
  257. self.git_wrapper.add(self.repo_work_dir, all_files=True)
  258. self.git_wrapper.commit(self.repo_work_dir, commit_message)
  259. self.git_wrapper.push(
  260. repo_dir=self.repo_work_dir,
  261. token=self.auth_token,
  262. url=remote_url,
  263. local_branch=branch,
  264. remote_branch=branch)
  265. def _get_repo_url(self, dataset_id):
  266. return f'{get_endpoint()}/datasets/{dataset_id}.git'
  267. def _get_remote_url(self):
  268. try:
  269. remote = self.git_wrapper.get_repo_remote_url(self.repo_work_dir)
  270. except GitError:
  271. remote = None
  272. return remote