| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import os
- from argparse import ArgumentParser
- from modelscope.cli.base import CLICommand
- from modelscope.hub.api import HubApi
- from modelscope.hub.constants import DEFAULT_MAX_WORKERS
- from modelscope.hub.file_download import (dataset_file_download,
- model_file_download)
- from modelscope.hub.snapshot_download import (dataset_snapshot_download,
- snapshot_download)
- from modelscope.hub.utils.utils import convert_patterns
- from modelscope.utils.constant import DEFAULT_DATASET_REVISION
def subparser_func(args):
    """Build the command object bound to the ``download`` sub parser.

    Registered as the sub parser's ``func`` default, so it receives the
    parsed argument namespace and hands it to ``DownloadCMD``.
    """
    command = DownloadCMD(args)
    return command
class DownloadCMD(CLICommand):
    """CLI ``download`` command: fetch a file, a set of files or a whole repo.

    The target repo is given either with ``--model`` / ``--dataset`` or as the
    positional ``repo_id`` combined with ``--repo-type``.
    """
    name = 'download'

    def __init__(self, args):
        # Parsed argparse namespace; read (and normalised) by ``execute``.
        self.args = args

    @staticmethod
    def define_args(parsers: ArgumentParser):
        """ define args for download command.
        """
        parser: ArgumentParser = parsers.add_parser(DownloadCMD.name)
        # --model and --dataset cannot be combined in one invocation.
        group = parser.add_mutually_exclusive_group()
        group.add_argument(
            '--model',
            type=str,
            help='The id of the model to be downloaded. For download, '
            'the id of either a model or dataset must be provided.')
        group.add_argument(
            '--dataset',
            type=str,
            help='The id of the dataset to be downloaded. For download, '
            'the id of either a model or dataset must be provided.')
        parser.add_argument(
            'repo_id',
            type=str,
            nargs='?',
            default=None,
            help='Optional, '
            'ID of the repo to download, It can also be set by --model or --dataset.'
        )
        parser.add_argument(
            '--repo-type',
            choices=['model', 'dataset'],
            default='model',
            help="Type of repo to download from (defaults to 'model').",
        )
        parser.add_argument(
            '--token',
            type=str,
            default=None,
            help='Optional. Access token to download controlled entities.')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='Revision of the entity (e.g., model).')
        parser.add_argument(
            '--cache_dir',
            type=str,
            default=None,
            help='Cache directory to save entity (e.g., model).')
        # NOTE: adjacent string literals are concatenated verbatim, so each
        # fragment below carries its own separating space.
        parser.add_argument(
            '--local_dir',
            type=str,
            default=None,
            help='File will be downloaded to local location specified by '
            'local_dir, in this case, cache_dir parameter will be ignored.')
        parser.add_argument(
            'files',
            type=str,
            default=None,
            nargs='*',
            help='Specify relative path to the repository file(s) to download. '
            "(e.g 'tokenizer.json', 'onnx/decoder_model.onnx').")
        parser.add_argument(
            '--include',
            nargs='*',
            default=None,
            type=str,
            help='Glob patterns to match files to download. '
            'Ignored if file is specified')
        parser.add_argument(
            '--exclude',
            nargs='*',
            type=str,
            default=None,
            help='Glob patterns to exclude from files to download. '
            'Ignored if file is specified')
        parser.add_argument(
            '--max-workers',
            type=int,
            default=DEFAULT_MAX_WORKERS,
            help='The maximum number of workers to download files.')
        parser.set_defaults(func=subparser_func)

    def _resolve_target(self):
        """Normalise ``self.args`` so model/dataset and ``files`` are settled.

        Raises:
            ValueError: if ``repo_type`` is neither ``'model'`` nor
                ``'dataset'``, or if no model/dataset id could be determined.
        """
        args = self.args
        # A hand-built namespace may carry ``files=None``; treat it as empty
        # so the length checks in ``execute`` cannot fail.
        if args.files is None:
            args.files = []

        if args.model or args.dataset:
            # The repo is already named by --model/--dataset, so whatever was
            # captured by the positional ``repo_id`` is really the first file.
            if args.repo_id is not None:
                args.files = [args.repo_id, *args.files]
        elif args.repo_id is not None:
            if args.repo_type == 'model':
                args.model = args.repo_id
            elif args.repo_type == 'dataset':
                args.dataset = args.repo_id
            else:
                raise ValueError('Not support repo-type: %s'
                                 % args.repo_type)

        if not args.model and not args.dataset:
            raise ValueError('Model or dataset must be set.')

    def execute(self):
        """Run the download described by ``self.args``.

        One positional file triggers a single-file download; several files are
        passed as the allow-pattern of a snapshot download; with no files the
        whole repo is fetched, filtered by ``--include`` / ``--exclude``.

        Raises:
            ValueError: if the target repo cannot be determined (see
                ``_resolve_target``).
        """
        self._resolve_target()
        args = self.args

        cookies = None
        if args.token is not None:
            api = HubApi()
            cookies = api.get_cookies(access_token=args.token)

        # Pick the downloader pair for the repo kind; a model is checked first,
        # matching the precedence of the original branch order.
        if args.model:
            kind, repo_id = 'model', args.model
            fetch_file, fetch_repo = model_file_download, snapshot_download
            revision = args.revision
        else:
            kind, repo_id = 'dataset', args.dataset
            fetch_file = dataset_file_download
            fetch_repo = dataset_snapshot_download
            # Datasets fall back to the default dataset revision when none
            # is given; models pass the (possibly None) revision through.
            revision = args.revision if args.revision else DEFAULT_DATASET_REVISION

        files = args.files
        if len(files) == 1:  # download single file
            fetch_file(
                repo_id,
                files[0],
                cache_dir=args.cache_dir,
                local_dir=args.local_dir,
                revision=revision,
                cookies=cookies)
        elif len(files) > 1:  # download specified multiple files.
            fetch_repo(
                repo_id,
                revision=revision,
                cache_dir=args.cache_dir,
                local_dir=args.local_dir,
                allow_file_pattern=files,
                max_workers=args.max_workers,
                cookies=cookies)
        else:  # download repo
            fetch_repo(
                repo_id,
                revision=revision,
                cache_dir=args.cache_dir,
                local_dir=args.local_dir,
                allow_file_pattern=convert_patterns(args.include),
                ignore_file_pattern=convert_patterns(args.exclude),
                max_workers=args.max_workers,
                cookies=cookies)

        print(f'\nSuccessfully Downloaded from {kind} {repo_id}.\n')
|