| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import fnmatch
- import os
- import re
- import uuid
- from contextlib import nullcontext
- from http.cookiejar import CookieJar
- from pathlib import Path
- from typing import Dict, List, Optional, Type, Union
- from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
- DEFAULT_MODEL_REVISION,
- INTRA_CLOUD_ACCELERATION,
- REPO_TYPE_DATASET, REPO_TYPE_MODEL,
- REPO_TYPE_SUPPORT)
- from modelscope.utils.file_utils import get_modelscope_cache_dir
- from modelscope.utils.logger import get_logger
- from modelscope.utils.thread_utils import thread_executor
- from .api import HubApi, ModelScopeConfig
- from .callback import ProgressCallback
- from .constants import DEFAULT_MAX_WORKERS
- from .errors import InvalidParameter
- from .file_download import (create_temporary_directory_and_cache,
- download_file, get_file_download_url)
- from .utils.caching import ModelFileSystemCache
- from .utils.utils import (get_model_masked_directory,
- model_id_to_group_owner_name, strtobool,
- weak_file_lock)
logger = get_logger()


def snapshot_download(
    model_id: str = None,
    revision: Optional[str] = None,
    cache_dir: Union[str, Path, None] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
    ignore_file_pattern: Optional[Union[str, List[str]]] = None,
    allow_file_pattern: Optional[Union[str, List[str]]] = None,
    local_dir: Optional[str] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: Optional[int] = None,
    repo_id: str = None,
    repo_type: Optional[str] = REPO_TYPE_MODEL,
    enable_file_lock: Optional[bool] = None,
    progress_callbacks: List[Type[ProgressCallback]] = None,
) -> str:
    """Download all files of a repo.

    Downloads a whole snapshot of a repo's files at the specified revision. This
    is useful when you want all files from a repo, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a repo but this would require that the
    user always has git and git-lfs installed, and properly configured.

    Args:
        model_id (str): A user or an organization name and a model name separated by a `/`.
            If `repo_id` is provided, `model_id` is ignored.
        revision (str, optional): A Git revision id, which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag names are supported.
            When `None`, `DEFAULT_DATASET_REVISION` is used for dataset repos and
            `DEFAULT_MODEL_REVISION` otherwise.
        cache_dir (str, Path, optional): Path to the folder where cached files are stored; the
            model will be saved as cache_dir/model_id/THE_MODEL_FILES.
        user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading the files and return the
            path to the locally cached files if they exist.
        cookies (CookieJar, optional): The cookies of the request, default None.
        ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be ignored in downloading, like exact file names or file extensions.
        allow_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be downloaded, like exact file names or file extensions.
        local_dir (str, optional): Specific local directory path to which the files will be downloaded.
        allow_patterns (`str` or `List`, *optional*, default to `None`):
            If provided, only files matching at least one pattern are downloaded, priority over
            allow_file_pattern. For hugging-face compatibility.
        ignore_patterns (`str` or `List`, *optional*, default to `None`):
            If provided, files matching any of the patterns are not downloaded, priority over
            ignore_file_pattern. For hugging-face compatibility.
        max_workers (`int`, optional): The maximum number of workers used to download files.
            Falsy values fall back to `DEFAULT_MAX_WORKERS`.
        repo_id (str): A user or an organization name and a repo name separated by a `/`.
        repo_type (str, optional): The type of the repo, either 'model' or 'dataset'.
            `None` is treated as a model repo.
        enable_file_lock (`bool`, optional): Enable a per-repo file lock, which is useful when
            downloading from multiple processes. When `None`, the value of the
            `MODELSCOPE_HUB_FILE_LOCK` env variable is used (default `true`); set it to
            `false` if the file lock causes problems and you cannot change your code.
        progress_callbacks (`List[Type[ProgressCallback]]`, *optional*, default to `None`):
            Progress callbacks to track the download progress.

    Raises:
        ValueError: if neither `repo_id` nor `model_id` is given, if `repo_type` is not
            supported, or if `local_files_only=True` and no cached files are found.
            Errors raised by the hub API while resolving or downloading files propagate.

    Returns:
        str: Local folder path (string) of the repo snapshot.
    """
    repo_id = repo_id or model_id
    if not repo_id:
        raise ValueError('Please provide a valid model_id or repo_id')
    # `repo_type` is annotated Optional and `_snapshot_download` already maps a
    # falsy value to a model repo, so accept None here instead of rejecting it.
    repo_type = repo_type or REPO_TYPE_MODEL
    if repo_type not in REPO_TYPE_SUPPORT:
        raise ValueError(
            f'Invalid repo type: {repo_type}, only support: {REPO_TYPE_SUPPORT}'
        )
    max_workers = max_workers or DEFAULT_MAX_WORKERS
    if revision is None:
        revision = (DEFAULT_DATASET_REVISION if repo_type == REPO_TYPE_DATASET
                    else DEFAULT_MODEL_REVISION)
    if enable_file_lock is None:
        enable_file_lock = strtobool(
            os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true'))
    if enable_file_lock:
        # One lock file per repo under <cache>/.lock, named after the repo id
        # with '/' replaced so it is a single path component.
        system_cache = (cache_dir if cache_dir is not None else
                        get_modelscope_cache_dir())
        lock_dir = os.path.join(system_cache, '.lock')
        os.makedirs(lock_dir, exist_ok=True)
        lock_file = os.path.join(lock_dir, repo_id.replace('/', '___'))
        context = weak_file_lock(lock_file)
    else:
        context = nullcontext()
    with context:
        return _snapshot_download(
            repo_id,
            repo_type=repo_type,
            revision=revision,
            cache_dir=cache_dir,
            user_agent=user_agent,
            local_files_only=local_files_only,
            cookies=cookies,
            ignore_file_pattern=ignore_file_pattern,
            allow_file_pattern=allow_file_pattern,
            local_dir=local_dir,
            ignore_patterns=ignore_patterns,
            allow_patterns=allow_patterns,
            max_workers=max_workers,
            progress_callbacks=progress_callbacks)
def dataset_snapshot_download(
    dataset_id: str,
    revision: Optional[str] = DEFAULT_DATASET_REVISION,
    cache_dir: Union[str, Path, None] = None,
    local_dir: Optional[str] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
    ignore_file_pattern: Optional[Union[str, List[str]]] = None,
    allow_file_pattern: Optional[Union[str, List[str]]] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    enable_file_lock: Optional[bool] = None,
    max_workers: int = 8,
    progress_callbacks: Optional[List[Type[ProgressCallback]]] = None,
) -> str:
    """Download raw files of a dataset.

    Downloads all files at the specified revision. This
    is useful when you want all files from a dataset, because you don't know which
    ones you will need a priori. All files are nested inside a folder in order
    to keep their actual filename relative to that folder.

    An alternative would be to just clone a dataset but this would require that the
    user always has git and git-lfs installed, and properly configured.

    Args:
        dataset_id (str): A user or an organization name and a dataset name separated by a `/`.
        revision (str, optional): A Git revision id, which can be a branch name, a tag, or a
            commit hash. NOTE: currently only branch and tag names are supported.
            `None` falls back to `DEFAULT_DATASET_REVISION`.
        cache_dir (str, Path, optional): Path to the folder where cached files are stored; the
            dataset will be saved as cache_dir/dataset_id/THE_DATASET_FILES.
        local_dir (str, optional): Specific local directory path to which the files will be downloaded.
        user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading the files and return the
            path to the locally cached files if they exist.
        cookies (CookieJar, optional): The cookies of the request, default None.
        ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be ignored in downloading, like exact file names or file extensions.
            Using regular expressions is deprecated.
        allow_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be downloaded, like exact file names or file extensions.
        allow_patterns (`str` or `List`, *optional*, default to `None`):
            If provided, only files matching at least one pattern are downloaded, priority over
            allow_file_pattern. For hugging-face compatibility.
        ignore_patterns (`str` or `List`, *optional*, default to `None`):
            If provided, files matching any of the patterns are not downloaded, priority over
            ignore_file_pattern. For hugging-face compatibility.
        enable_file_lock (`bool`, optional): Enable a per-dataset file lock, which is useful when
            downloading from multiple processes. When `None`, the value of the
            `MODELSCOPE_HUB_FILE_LOCK` env variable is used (default `true`); set it to
            `false` if the file lock causes problems and you cannot change your code.
        max_workers (`int`): The maximum number of workers used to download files, default 8.
        progress_callbacks (`List[Type[ProgressCallback]]`, *optional*, default to `None`):
            Progress callbacks to track the download progress.

    Raises:
        ValueError: if `dataset_id` is empty, or if `local_files_only=True` and no cached
            files are found. Errors raised by the hub API propagate.

    Returns:
        str: Local folder path (string) of the dataset snapshot.
    """
    if not dataset_id:
        raise ValueError('Please provide a valid dataset_id')
    # Mirror snapshot_download: an explicit None means the default revision.
    if revision is None:
        revision = DEFAULT_DATASET_REVISION
    if enable_file_lock is None:
        enable_file_lock = strtobool(
            os.environ.get('MODELSCOPE_HUB_FILE_LOCK', 'true'))
    if enable_file_lock:
        # One lock file per dataset under <cache>/.lock, named after the id
        # with '/' replaced so it is a single path component.
        system_cache = (cache_dir if cache_dir is not None else
                        get_modelscope_cache_dir())
        lock_dir = os.path.join(system_cache, '.lock')
        os.makedirs(lock_dir, exist_ok=True)
        lock_file = os.path.join(lock_dir, dataset_id.replace('/', '___'))
        context = weak_file_lock(lock_file)
    else:
        context = nullcontext()
    with context:
        return _snapshot_download(
            dataset_id,
            repo_type=REPO_TYPE_DATASET,
            revision=revision,
            cache_dir=cache_dir,
            user_agent=user_agent,
            local_files_only=local_files_only,
            cookies=cookies,
            ignore_file_pattern=ignore_file_pattern,
            allow_file_pattern=allow_file_pattern,
            local_dir=local_dir,
            ignore_patterns=ignore_patterns,
            allow_patterns=allow_patterns,
            max_workers=max_workers,
            progress_callbacks=progress_callbacks)
def _resolve_snapshot_directory(repo_id: str, repo_type: str, system_cache,
                                cache_dir, local_dir) -> str:
    """Return the target directory shown to the user for a snapshot.

    ``local_dir`` wins; otherwise the repo sits directly under an explicit
    ``cache_dir``, or under ``<system_cache>/models`` (model repos) or
    ``<system_cache>/datasets`` (dataset repos).
    """
    if local_dir:
        return os.path.abspath(local_dir)
    if cache_dir:
        return os.path.join(system_cache, *repo_id.split('/'))
    sub_dir = 'models' if repo_type == REPO_TYPE_MODEL else 'datasets'
    return os.path.join(system_cache, sub_dir, *repo_id.split('/'))


def _snapshot_download(
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = DEFAULT_MODEL_REVISION,
    cache_dir: Union[str, Path, None] = None,
    user_agent: Optional[Union[Dict, str]] = None,
    local_files_only: Optional[bool] = False,
    cookies: Optional[CookieJar] = None,
    ignore_file_pattern: Optional[Union[str, List[str]]] = None,
    allow_file_pattern: Optional[Union[str, List[str]]] = None,
    local_dir: Optional[str] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 8,
    progress_callbacks: List[Type[ProgressCallback]] = None,
):
    """Shared implementation of ``snapshot_download`` and
    ``dataset_snapshot_download``.

    Parameters have the same meaning as in ``snapshot_download``; a falsy
    ``repo_type`` means a model repo.

    With ``local_files_only`` the cached snapshot root is returned without any
    network access. Otherwise the revision is resolved, the repo file list is
    fetched (paginated for datasets), the files are filtered and downloaded,
    and the resolved revision info is saved to the cache.

    Returns:
        The cache root location of the snapshot, or ``None`` if the dataset
        file listing comes back as ``None``.

    Raises:
        InvalidParameter: if ``repo_type`` is not supported.
        ValueError: if ``local_files_only`` is set and nothing is cached.
    """
    if not repo_type:
        repo_type = REPO_TYPE_MODEL
    if repo_type not in REPO_TYPE_SUPPORT:
        raise InvalidParameter('Invalid repo type: %s, only support: %s' %
                               (repo_type, REPO_TYPE_SUPPORT))

    temporary_cache_dir, cache = create_temporary_directory_and_cache(
        repo_id, local_dir=local_dir, cache_dir=cache_dir, repo_type=repo_type)
    system_cache = (cache_dir if cache_dir is not None else
                    get_modelscope_cache_dir())

    if local_files_only:
        if len(cache.cached_files) == 0:
            raise ValueError(
                'Cannot find the requested files in the cached path and outgoing'
                ' traffic has been disabled. To enable look-ups and downloads'
                " online, set 'local_files_only' to False.")
        logger.warning('We can not confirm the cached file is for revision: %s'
                       % revision)
        # we can not confirm the cached file is for snapshot 'revision'
        return cache.get_root_location()

    # make headers
    headers = {
        'user-agent':
        ModelScopeConfig.get_user_agent(user_agent=user_agent, ),
        'snapshot-identifier': str(uuid.uuid4()),
    }
    # A single client serves the acceleration lookup and every later call.
    _api = HubApi()
    if INTRA_CLOUD_ACCELERATION == 'true':
        region_id: str = (
            os.getenv('INTRA_CLOUD_ACCELERATION_REGION')
            or _api._get_internal_acceleration_domain())
        if region_id:
            logger.info(
                f'Intra-cloud acceleration enabled for downloading from {repo_id}'
            )
            headers['x-aliyun-region-id'] = region_id
    endpoint = _api.get_endpoint_for_read(
        repo_id=repo_id, repo_type=repo_type)
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()

    if repo_type == REPO_TYPE_MODEL:
        directory = _resolve_snapshot_directory(repo_id, repo_type,
                                                system_cache, cache_dir,
                                                local_dir)
        print(f'Downloading Model from {endpoint} to directory: {directory}')
        revision_detail = _api.get_valid_revision_detail(
            repo_id, revision=revision, cookies=cookies, endpoint=endpoint)
        revision = revision_detail['Revision']

        # Add snapshot-ci-test for counting the ci test download
        if 'CI_TEST' in os.environ:
            snapshot_header = {**headers, **{'snapshot-ci-test': 'True'}}
        else:
            snapshot_header = {**headers, **{'Snapshot': 'True'}}
        if cache.cached_model_revision is not None:
            snapshot_header[
                'cached_model_revision'] = cache.cached_model_revision

        repo_files = _api.get_model_files(
            model_id=repo_id,
            revision=revision,
            recursive=True,
            use_cookies=False if cookies is None else cookies,
            headers=snapshot_header,
            endpoint=endpoint)
        _download_file_lists(
            repo_files,
            cache,
            temporary_cache_dir,
            repo_id,
            _api,
            None,
            None,
            headers,
            repo_type=repo_type,
            revision=revision,
            cookies=cookies,
            ignore_file_pattern=ignore_file_pattern,
            allow_file_pattern=allow_file_pattern,
            ignore_patterns=ignore_patterns,
            allow_patterns=allow_patterns,
            max_workers=max_workers,
            endpoint=endpoint,
            progress_callbacks=progress_callbacks,
        )
        # Model ids containing '.' are stored in a masked directory; expose
        # them at the expected path through a symlink when possible.
        if '.' in repo_id:
            masked_directory = get_model_masked_directory(directory, repo_id)
            if os.path.exists(directory):
                logger.info(
                    'Target directory already exists, skipping creation.')
            else:
                logger.info(f'Creating symbolic link [{directory}].')
                try:
                    os.symlink(
                        os.path.abspath(masked_directory),
                        directory,
                        target_is_directory=True)
                except OSError:
                    logger.warning(
                        f'Failed to create symbolic link {directory} for {os.path.abspath(masked_directory)}.'
                    )

    elif repo_type == REPO_TYPE_DATASET:
        directory = _resolve_snapshot_directory(repo_id, repo_type,
                                                system_cache, cache_dir,
                                                local_dir)
        print(f'Downloading Dataset to directory: {directory}')
        group_or_owner, name = model_id_to_group_owner_name(repo_id)
        # Resolve the revision once and use the same value for listing files
        # and for building per-file download URLs, so an explicit
        # revision=None never reaches the URL builder.
        revision_detail = revision or DEFAULT_DATASET_REVISION
        logger.info('Fetching dataset repo file list...')
        repo_files = fetch_repo_files(_api, repo_id, revision_detail,
                                      endpoint)
        if repo_files is None:
            logger.error(
                f'Failed to retrieve file list for dataset: {repo_id}')
            return None
        _download_file_lists(
            repo_files,
            cache,
            temporary_cache_dir,
            repo_id,
            _api,
            name,
            group_or_owner,
            headers,
            repo_type=repo_type,
            revision=revision_detail,
            cookies=cookies,
            ignore_file_pattern=ignore_file_pattern,
            allow_file_pattern=allow_file_pattern,
            ignore_patterns=ignore_patterns,
            allow_patterns=allow_patterns,
            max_workers=max_workers,
            endpoint=endpoint,
            progress_callbacks=progress_callbacks,
        )

    else:
        # REPO_TYPE_SUPPORT may list a type neither branch handles; fail
        # clearly instead of hitting an unbound `revision_detail` below.
        raise InvalidParameter(
            f'Invalid repo type: {repo_type}, supported types: {REPO_TYPE_SUPPORT}'
        )

    cache.save_model_version(revision_info=revision_detail)
    return cache.get_root_location()
def fetch_repo_files(_api, repo_id, revision, endpoint, page_size=150):
    """Collect the recursive file listing of a dataset repo, page by page.

    Args:
        _api: client exposing ``get_dataset_files`` (a ``HubApi`` in this module).
        repo_id: ``namespace/name`` of the dataset.
        revision: revision whose files are listed.
        endpoint: hub endpoint to query.
        page_size: number of entries requested per page (default 150). A page
            with fewer entries than this is taken to be the last one.

    Returns:
        list: the concatenated entries of all pages. If a page request raises,
        the error is logged and the entries gathered so far are returned
        (possibly an empty list).

    Raises:
        ValueError: if ``page_size`` is not positive; such a value would keep
            requesting pages forever.
    """
    if page_size <= 0:
        raise ValueError(f'page_size must be positive, got {page_size}')
    page_number = 1
    repo_files = []
    while True:
        try:
            dataset_files = _api.get_dataset_files(
                repo_id=repo_id,
                revision=revision,
                root_path='/',
                recursive=True,
                page_number=page_number,
                page_size=page_size,
                endpoint=endpoint)
        except Exception as e:
            logger.error(f'Error fetching dataset files: {e}')
            break
        if not dataset_files:
            # Empty (or None) page: nothing further to fetch.
            break
        repo_files.extend(dataset_files)
        if len(dataset_files) < page_size:
            break
        page_number += 1
    return repo_files
- def _is_valid_regex(pattern: str):
- try:
- re.compile(pattern)
- return True
- except BaseException:
- return False
- def _normalize_patterns(patterns: Union[str, List[str]]):
- if isinstance(patterns, str):
- patterns = [patterns]
- if patterns is not None:
- patterns = [
- item if not item.endswith('/') else item + '*' for item in patterns
- ]
- return patterns
- def _get_valid_regex_pattern(patterns: List[str]):
- if patterns is not None:
- regex_patterns = []
- for item in patterns:
- if _is_valid_regex(item):
- regex_patterns.append(item)
- return regex_patterns
- else:
- return None
def _matches_any_glob(path: str, patterns: List[str]) -> bool:
    """Return True if ``path`` matches at least one fnmatch-style pattern."""
    return any(fnmatch.fnmatch(path, pattern) for pattern in patterns)


def _download_file_lists(
    repo_files: List[Dict],
    cache: ModelFileSystemCache,
    temporary_cache_dir: str,
    repo_id: str,
    api: HubApi,
    name: str,
    group_or_owner: str,
    headers,
    repo_type: Optional[str] = None,
    revision: Optional[str] = DEFAULT_MODEL_REVISION,
    cookies: Optional[CookieJar] = None,
    ignore_file_pattern: Optional[Union[str, List[str]]] = None,
    allow_file_pattern: Optional[Union[str, List[str]]] = None,
    allow_patterns: Optional[Union[List[str], str]] = None,
    ignore_patterns: Optional[Union[List[str], str]] = None,
    max_workers: int = 8,
    endpoint: Optional[str] = None,
    progress_callbacks: List[Type[ProgressCallback]] = None,
):
    """Filter ``repo_files`` by patterns and cache state, then download the rest.

    For each entry (a dict read via its 'Type', 'Path' and 'Name' keys):
    directory entries (``Type == 'tree'``) are skipped. The entry is dropped
    if its 'Path' matches any glob in ``ignore_patterns`` or
    ``ignore_file_pattern``, or if its 'Name' matches any regex-valid entry of
    ``ignore_file_pattern`` (legacy regex support). When ``allow_patterns`` or
    ``allow_file_pattern`` are non-empty, the 'Path' must match one of each.
    Entries already present in ``cache`` are skipped. If evaluating an entry
    raises, a warning is logged and that entry is skipped.

    The remaining files are downloaded through ``download_file`` with up to
    ``max_workers`` concurrent workers.

    Args:
        name / group_or_owner: dataset name and namespace; only used to build
            dataset download URLs.
        headers / cookies: forwarded to ``download_file``.
        Other arguments have the same meaning as in ``snapshot_download``.

    Raises:
        InvalidParameter: if ``repo_type`` is neither model nor dataset and
            there is at least one file to download.
    """
    # Legacy regex matching must see ``ignore_file_pattern`` exactly as the
    # caller wrote it. Deriving it from the normalized list would turn a
    # directory glob such as 'data/' into the regex 'data/*', and re.search
    # would then drop every file whose name merely contains 'data'
    # (e.g. 'metadata.json').
    if isinstance(ignore_file_pattern, str):
        ignore_regex_pattern = _get_valid_regex_pattern([ignore_file_pattern])
    else:
        ignore_regex_pattern = _get_valid_regex_pattern(ignore_file_pattern)

    ignore_patterns = _normalize_patterns(ignore_patterns)
    allow_patterns = _normalize_patterns(allow_patterns)
    ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
    allow_file_pattern = _normalize_patterns(allow_file_pattern)

    filtered_repo_files = []
    for repo_file in repo_files:
        if repo_file['Type'] == 'tree':
            continue
        try:
            path = repo_file['Path']
            if ignore_patterns and _matches_any_glob(path, ignore_patterns):
                continue
            if ignore_file_pattern and _matches_any_glob(
                    path, ignore_file_pattern):
                continue
            if ignore_regex_pattern and any(
                    re.search(pattern, repo_file['Name']) is not None
                    for pattern in ignore_regex_pattern):
                continue
            if allow_patterns and not _matches_any_glob(path, allow_patterns):
                continue
            if allow_file_pattern and not _matches_any_glob(
                    path, allow_file_pattern):
                continue
            # check model_file is exist in cache, if existed, skip download
            if cache.exists(repo_file):
                file_name = os.path.basename(repo_file['Name'])
                logger.debug(
                    f'File {file_name} already in cache with identical hash, skip downloading!'
                )
                continue
        except Exception as e:
            logger.warning('The file pattern is invalid : %s' % e)
        else:
            # Only reached when no `continue` and no exception happened.
            filtered_repo_files.append(repo_file)

    @thread_executor(max_workers=max_workers, disable_tqdm=False)
    def _download_single_file(repo_file):
        if repo_type == REPO_TYPE_MODEL:
            url = get_file_download_url(
                model_id=repo_id,
                file_path=repo_file['Path'],
                revision=revision,
                endpoint=endpoint)
        elif repo_type == REPO_TYPE_DATASET:
            url = api.get_dataset_file_url(
                file_name=repo_file['Path'],
                dataset_name=name,
                namespace=group_or_owner,
                revision=revision,
                endpoint=endpoint)
        else:
            raise InvalidParameter(
                f'Invalid repo type: {repo_type}, supported types: {REPO_TYPE_SUPPORT}'
            )
        download_file(
            url,
            repo_file,
            temporary_cache_dir,
            cache,
            headers,
            cookies,
            disable_tqdm=False,
            progress_callbacks=progress_callbacks,
        )

    if len(filtered_repo_files) > 0:
        logger.info(
            f'Got {len(filtered_repo_files)} files, start to download ...')
        _download_single_file(filtered_repo_files)
        logger.info(f"Download {repo_type} '{repo_id}' successfully.")
|