download.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from argparse import ArgumentParser
  4. from modelscope.cli.base import CLICommand
  5. from modelscope.hub.api import HubApi
  6. from modelscope.hub.constants import DEFAULT_MAX_WORKERS
  7. from modelscope.hub.file_download import (dataset_file_download,
  8. model_file_download)
  9. from modelscope.hub.snapshot_download import (dataset_snapshot_download,
  10. snapshot_download)
  11. from modelscope.hub.utils.utils import convert_patterns
  12. from modelscope.utils.constant import DEFAULT_DATASET_REVISION
  def subparser_func(args):
      """Create the command object for the ``download`` sub-parser.

      This is the callable registered on the sub-parser through
      ``parser.set_defaults(func=subparser_func)``.

      Args:
          args: The parsed command-line namespace, passed unchanged to
              ``DownloadCMD``.

      Returns:
          A ``DownloadCMD`` instance wrapping ``args``.
      """
      command = DownloadCMD(args)
      return command
  class DownloadCMD(CLICommand):
      """The ``download`` CLI command.

      Downloads a single file, a selection of files, or an entire repository
      for either a model or a dataset, depending on the parsed arguments
      stored in ``self.args``.
      """
      name = 'download'

      def __init__(self, args):
          """Keep the parsed command-line namespace for later use by ``execute``."""
          self.args = args

      @staticmethod
      def define_args(parsers: ArgumentParser):
          """Register the ``download`` sub-parser and all of its arguments.

          Args:
              parsers: The sub-parsers object on which ``add_parser`` is called.
          """
          parser: ArgumentParser = parsers.add_parser(DownloadCMD.name)

          # --model and --dataset cannot be given together.
          group = parser.add_mutually_exclusive_group()
          group.add_argument(
              '--model',
              type=str,
              help='The id of the model to be downloaded. For download, '
              'the id of either a model or dataset must be provided.')
          group.add_argument(
              '--dataset',
              type=str,
              help='The id of the dataset to be downloaded. For download, '
              'the id of either a model or dataset must be provided.')

          # NOTE: the order of the two positionals (repo_id, then files) matters;
          # ``execute`` relies on the first positional landing in ``repo_id``.
          parser.add_argument(
              'repo_id',
              type=str,
              nargs='?',
              default=None,
              help='Optional, '
              'ID of the repo to download, It can also be set by --model or --dataset.'
          )
          parser.add_argument(
              '--repo-type',
              choices=['model', 'dataset'],
              default='model',
              help="Type of repo to download from (defaults to 'model').",
          )
          parser.add_argument(
              '--token',
              type=str,
              default=None,
              help='Optional. Access token to download controlled entities.')
          parser.add_argument(
              '--revision',
              type=str,
              default=None,
              help='Revision of the entity (e.g., model).')
          parser.add_argument(
              '--cache_dir',
              type=str,
              default=None,
              help='Cache directory to save entity (e.g., model).')
          # The help texts below are built from adjacent string literals; each
          # first fragment ends with a space so the sentences do not run together.
          parser.add_argument(
              '--local_dir',
              type=str,
              default=None,
              help='File will be downloaded to local location specified by '
              'local_dir, in this case, cache_dir parameter will be ignored.')
          parser.add_argument(
              'files',
              type=str,
              default=None,
              nargs='*',
              help='Specify relative path to the repository file(s) to download. '
              "(e.g 'tokenizer.json', 'onnx/decoder_model.onnx').")
          parser.add_argument(
              '--include',
              nargs='*',
              default=None,
              type=str,
              help='Glob patterns to match files to download. '
              'Ignored if file is specified')
          parser.add_argument(
              '--exclude',
              nargs='*',
              type=str,
              default=None,
              help='Glob patterns to exclude from files to download. '
              'Ignored if file is specified')
          parser.add_argument(
              '--max-workers',
              type=int,
              default=DEFAULT_MAX_WORKERS,
              help='The maximum number of workers to download files.')
          parser.set_defaults(func=subparser_func)

      def execute(self):
          """Run the download described by ``self.args``.

          Resolves which repository to fetch (``--model`` / ``--dataset`` or the
          positional ``repo_id`` combined with ``--repo-type``), obtains cookies
          when ``--token`` is given, performs the download and prints a success
          message.

          Raises:
              ValueError: If neither a model nor a dataset id could be
                  determined, or if ``repo_type`` is not ``'model'`` or
                  ``'dataset'``.
          """
          args = self.args
          self._normalize_targets()

          if not args.model and not args.dataset:
              raise ValueError('Model or dataset must be set.')

          cookies = None
          if args.token is not None:
              cookies = HubApi().get_cookies(access_token=args.token)

          if args.model:
              self._download(
                  repo_id=args.model,
                  revision=args.revision,
                  file_download=model_file_download,
                  repo_download=snapshot_download,
                  cookies=cookies)
              print(f'\nSuccessfully Downloaded from model {args.model}.\n')
          else:
              # Datasets fall back to DEFAULT_DATASET_REVISION when no revision
              # was requested; models pass the revision through unchanged.
              dataset_revision: str = (
                  args.revision if args.revision else DEFAULT_DATASET_REVISION)
              self._download(
                  repo_id=args.dataset,
                  revision=dataset_revision,
                  file_download=dataset_file_download,
                  repo_download=dataset_snapshot_download,
                  cookies=cookies)
              print(
                  f'\nSuccessfully Downloaded from dataset {args.dataset}.\n')

      def _normalize_targets(self):
          """Reconcile the positional ``repo_id`` with ``--model`` / ``--dataset``.

          Updates ``self.args`` in place so that afterwards ``model`` or
          ``dataset`` holds the repository id (when one was supplied) and
          ``files`` is always a list.
          """
          args = self.args
          if args.model or args.dataset:
              # The repo was named by option, so the first positional that
              # argparse stored in ``repo_id`` is really the first file.
              if args.repo_id is not None:
                  args.files = [args.repo_id, *(args.files or [])]
          elif args.repo_id is not None:
              if args.repo_type == 'model':
                  args.model = args.repo_id
              elif args.repo_type == 'dataset':
                  args.dataset = args.repo_id
              else:
                  raise ValueError('Not support repo-type: %s'
                                   % args.repo_type)
          # Guard against a namespace whose ``files`` is None (e.g. one built
          # by hand) so the length checks in ``_download`` cannot fail.
          if args.files is None:
              args.files = []

      def _download(self, repo_id, revision, file_download, repo_download,
                    cookies):
          """Dispatch to a single-file or snapshot download for one repository.

          Args:
              repo_id: Id of the model or dataset to download from.
              revision: Revision value forwarded to the download function.
              file_download: Callable used when exactly one file is requested.
              repo_download: Callable used for several files or the whole repo.
              cookies: Cookies forwarded to the download function (may be None).
          """
          args = self.args
          files = args.files
          shared = dict(
              cache_dir=args.cache_dir,
              local_dir=args.local_dir,
              revision=revision,
              cookies=cookies)

          if len(files) == 1:  # download single file
              file_download(repo_id, files[0], **shared)
          elif files:  # download specified multiple files.
              repo_download(
                  repo_id,
                  allow_file_pattern=files,
                  max_workers=args.max_workers,
                  **shared)
          else:  # download repo, honouring --include / --exclude
              repo_download(
                  repo_id,
                  allow_file_pattern=convert_patterns(args.include),
                  ignore_file_pattern=convert_patterns(args.exclude),
                  max_workers=args.max_workers,
                  **shared)