| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- import subprocess
- from typing import List, Optional
- from modelscope.utils.logger import get_logger
- from ..utils.constant import MASTER_MODEL_BRANCH
- from .errors import GitError
# Module-level logger used by every git helper below.
logger = get_logger()
class Singleton(type):
    """Metaclass that turns every class using it into a singleton.

    The first call to a class builds the instance (running ``__init__``) and
    caches it in the shared ``_instances`` dict keyed by class; every later
    call returns the cached object and ignores its arguments. No locking is
    performed.
    """
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        created = super(Singleton, cls).__call__(*args, **kwargs)
        registry[cls] = created
        return created
class GitCommandWrapper(metaclass=Singleton):
    """Thin wrapper around the ``git`` command line.

    Every command goes through :meth:`_run_git_command`, which hands the
    arguments to git as a list (no shell is involved, so paths, messages and
    user names containing spaces or quotes are passed through verbatim) and
    raises :class:`GitError` when git exits with a non-zero status.
    """
    default_git_path = 'git'  # The default git command line

    def __init__(self, path: str = None):
        # With the Singleton metaclass, __init__ only runs on the first
        # construction, so only the first ``path`` given takes effect.
        self.git_path = path or self.default_git_path

    def _run_git_command(self, *args) -> subprocess.CompletedProcess:
        """Run a git command and return its response.

        Falsy arguments (``None``, ``''``) are dropped before running, so
        callers may pass optional values directly.

        Args:
            args: The git arguments, without the git executable itself.

        Raises:
            GitError: git exited with a non-zero status; the message is its
                stderr. A failure whose stdout contains 'nothing to commit'
                is treated as success instead.

        Returns:
            subprocess.CompletedProcess: the command response.
        """
        command = [item for item in (self.git_path, *args) if item]
        # Never write oauth2 tokens embedded in remote urls into the logs.
        safe_command = [
            self.remove_token_from_url(str(item)) for item in command
        ]
        logger.debug(' '.join(safe_command[1:]))
        git_env = os.environ.copy()
        # Fail instead of blocking on an interactive credential prompt.
        git_env['GIT_TERMINAL_PROMPT'] = '0'
        response = subprocess.run(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=git_env,
        )  # PIPE instead of capture_output=True: compatible for python3.6
        if response.returncode == 0:
            return response
        std_out = response.stdout.decode('utf8', errors='replace')
        std_err = response.stderr.decode('utf8', errors='replace')
        if 'nothing to commit' in std_out:
            logger.info(
                'Nothing to commit, your local repo is upto date with remote')
            return response
        logger.error(
            'Running git command: %s failed \n stdout: %s \n stderr: %s',
            safe_command, std_out, std_err)
        raise GitError(std_err)

    def config_auth_token(self, repo_dir, auth_token):
        """Embed ``auth_token`` in the 'origin' remote url of ``repo_dir``.

        Does nothing when the url already carries an oauth2 credential.
        """
        url = self.get_repo_remote_url(repo_dir)
        if '//oauth2' not in url:
            auth_url = self._add_token(auth_token, url)
            rsp = self._run_git_command('-C', repo_dir, 'remote', 'set-url',
                                        'origin', auth_url)
            logger.debug(rsp.stdout.decode('utf8'))

    def _add_token(self, token: str, url: str):
        """Return ``url`` with ``oauth2:<token>@`` inserted after its '//'.

        The url is returned unchanged when ``token`` is empty or the url
        already carries an oauth2 credential.
        """
        if token and '//oauth2' not in url:
            # Only the first '//' (after the scheme); later ones are path.
            url = url.replace('//', '//oauth2:%s@' % token, 1)
        return url

    def remove_token_from_url(self, url: str):
        """Strip an ``oauth2:<token>@`` credential from ``url`` if present."""
        if url and '//oauth2' in url:
            start_index = url.find('//oauth2') + 2
            # The credential ends at the first '@' after 'oauth2'.
            end_index = url.find('@', start_index)
            if end_index != -1:
                url = url[:start_index] + url[end_index + 1:]
        return url

    def is_lfs_installed(self):
        """Return True if ``git lfs env`` succeeds, False on GitError."""
        try:
            self._run_git_command('lfs', 'env')
            return True
        except GitError:
            return False

    def git_lfs_install(self, repo_dir):
        """Run ``git lfs install`` in ``repo_dir``; return True on success."""
        try:
            self._run_git_command('-C', repo_dir, 'lfs', 'install')
            return True
        except GitError:
            return False

    def clone(self,
              repo_base_dir: str,
              token: str,
              url: str,
              repo_name: str,
              branch: Optional[str] = None):
        """ git clone command wrapper.

        For a public project the token can be None; a private repo requires
        one.

        Args:
            repo_base_dir (str): The local base dir, the repository will be
                cloned to repo_base_dir/repo_name.
            token (str): The git token, must be provided for private project.
            url (str): The remote url.
            repo_name (str): The local repository path name.
            branch (str, optional): Branch to check out. Defaults to the
                remote's default branch.

        Returns:
            The subprocess response.
        """
        url = self._add_token(token, url)
        clone_args = ['-C', repo_base_dir, 'clone', url, repo_name]
        if branch:
            clone_args += ['--branch', branch]
        response = self._run_git_command(*clone_args)
        logger.debug(response.stdout.decode('utf8'))
        return response

    def add_user_info(self, repo_base_dir, repo_name):
        """Configure ``user.name``/``user.email`` in repo_base_dir/repo_name.

        Values come from ``ModelScopeConfig.get_user_info()``; nothing is
        configured unless both are set.
        """
        # Local import; presumably avoids an import cycle with the hub api.
        from modelscope.hub.api import ModelScopeConfig
        user_name, user_email = ModelScopeConfig.get_user_info()
        if user_name and user_email:
            repo_dir = os.path.join(repo_base_dir, repo_name)
            for key, value in (('user.name', user_name),
                               ('user.email', user_email)):
                response = self._run_git_command('-C', repo_dir, 'config',
                                                 key, value)
                logger.debug(response.stdout.decode('utf8'))

    def add(self,
            repo_dir: str,
            files: Optional[List[str]] = None,
            all_files: bool = False):
        """Run ``git add`` in ``repo_dir``.

        Args:
            repo_dir (str): the repository directory.
            files (List[str], optional): paths to stage; ignored when
                ``all_files`` is True.
            all_files (bool): stage every change (``git add -A``).

        Raises:
            ValueError: neither ``all_files`` nor any ``files`` was given.

        Returns:
            The command response.
        """
        if all_files:
            add_args = ['-C', repo_dir, 'add', '-A']
        elif files:
            # '--' keeps a path starting with '-' from being read as a flag.
            add_args = ['-C', repo_dir, 'add', '--', *files]
        else:
            raise ValueError('Either files or all_files=True must be given.')
        rsp = self._run_git_command(*add_args)
        logger.debug(rsp.stdout.decode('utf8'))
        return rsp

    def commit(self, repo_dir: str, message: str):
        """Run git commit command.

        Args:
            repo_dir (str): the repository directory.
            message (str): commit message, passed to git verbatim.

        Returns:
            The command response.
        """
        rsp = self._run_git_command('-C', repo_dir, 'commit', '-m', message)
        logger.info(rsp.stdout.decode('utf8'))
        return rsp

    def checkout(self, repo_dir: str, revision: str):
        """Run ``git checkout <revision>`` in ``repo_dir``."""
        return self._run_git_command('-C', repo_dir, 'checkout', revision)

    def new_branch(self, repo_dir: str, revision: str):
        """Create and switch to branch ``revision`` (``git checkout -b``)."""
        return self._run_git_command('-C', repo_dir, 'checkout', '-b',
                                     revision)

    def get_remote_branches(self, repo_dir: str):
        """Return remote branch names with the remote prefix stripped.

        For example 'origin/dev' becomes 'dev'. Symbolic entries such as
        'origin/HEAD -> origin/master' are skipped.
        """
        rsp = self._run_git_command('-C', repo_dir, 'branch', '-r')
        branches = []
        for line in rsp.stdout.decode('utf8').splitlines():
            line = line.strip()
            if not line or '->' in line:
                continue
            branches.append(line.partition('/')[2])
        return branches

    def pull(self,
             repo_dir: str,
             remote: str = 'origin',
             branch: str = 'master'):
        """Run ``git pull <remote> <branch>`` in ``repo_dir``."""
        return self._run_git_command('-C', repo_dir, 'pull', remote, branch)

    def push(self,
             repo_dir: str,
             token: str,
             url: str,
             local_branch: str,
             remote_branch: str,
             force: bool = False):
        """Push ``local_branch`` to ``remote_branch`` at ``url``.

        ``token`` (if any) is embedded into the url; ``force`` adds ``-f``.
        """
        url = self._add_token(token, url)
        push_args = [
            '-C', repo_dir, 'push', url,
            '%s:%s' % (local_branch, remote_branch)
        ]
        if force:
            push_args.append('-f')
        rsp = self._run_git_command(*push_args)
        logger.debug(rsp.stdout.decode('utf8'))
        return rsp

    def get_repo_remote_url(self, repo_dir: str):
        """Return the configured ``remote.origin.url`` of ``repo_dir``."""
        rsp = self._run_git_command('-C', repo_dir, 'config', '--get',
                                    'remote.origin.url')
        return rsp.stdout.decode('utf8').strip()

    def list_lfs_files(self, repo_dir: str):
        """Return the paths listed by ``git lfs ls-files`` in ``repo_dir``."""
        rsp = self._run_git_command('-C', repo_dir, 'lfs', 'ls-files')
        files = []
        for line in rsp.stdout.decode('utf8').splitlines():
            if line.strip():
                # Lines look like '<oid> <*|-> <path>'; the path itself may
                # contain spaces, so split off only the first two fields.
                files.append(line.split(' ', 2)[-1])
        return files

    def tag(self,
            repo_dir: str,
            tag_name: str,
            message: str,
            ref: str = MASTER_MODEL_BRANCH):
        """Create annotated tag ``tag_name`` on ``ref`` with ``message``."""
        rsp = self._run_git_command('-C', repo_dir, 'tag', tag_name, '-m',
                                    message, ref)
        logger.debug(rsp.stdout.decode('utf8'))
        return rsp

    def push_tag(self, repo_dir: str, tag_name):
        """Push ``tag_name`` to the 'origin' remote."""
        rsp = self._run_git_command('-C', repo_dir, 'push', 'origin',
                                    tag_name)
        logger.debug(rsp.stdout.decode('utf8'))
        return rsp
|