download.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # coding=utf-8
  2. # Copyright 202-present, the HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Contains command to download files from the Hub with the CLI.
  16. Usage:
  17. hf download --help
  18. # Download file
  19. hf download gpt2 config.json
  20. # Download entire repo
  21. hf download fffiloni/zeroscope --repo-type=space --revision=refs/pr/78
  22. # Download repo with filters
  23. hf download gpt2 --include="*.safetensors"
  24. # Download with token
  25. hf download Wauplin/private-model --token=hf_***
  26. # Download quietly (no progress bar, no warnings, only the returned path)
  27. hf download gpt2 config.json --quiet
  28. # Download to local dir
  29. hf download gpt2 --local-dir=./models/gpt2
  30. """
  31. import warnings
  32. from argparse import Namespace, _SubParsersAction
  33. from typing import List, Optional
  34. from huggingface_hub import logging
  35. from huggingface_hub._snapshot_download import snapshot_download
  36. from huggingface_hub.commands import BaseHuggingfaceCLICommand
  37. from huggingface_hub.file_download import hf_hub_download
  38. from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
# Module-level logger obtained from `huggingface_hub.logging`, keyed on this module's name.
logger = logging.get_logger(__name__)
  40. class DownloadCommand(BaseHuggingfaceCLICommand):
  41. @staticmethod
  42. def register_subcommand(parser: _SubParsersAction):
  43. download_parser = parser.add_parser("download", help="Download files from the Hub")
  44. download_parser.add_argument(
  45. "repo_id", type=str, help="ID of the repo to download from (e.g. `username/repo-name`)."
  46. )
  47. download_parser.add_argument(
  48. "filenames", type=str, nargs="*", help="Files to download (e.g. `config.json`, `data/metadata.jsonl`)."
  49. )
  50. download_parser.add_argument(
  51. "--repo-type",
  52. choices=["model", "dataset", "space"],
  53. default="model",
  54. help="Type of repo to download from (defaults to 'model').",
  55. )
  56. download_parser.add_argument(
  57. "--revision",
  58. type=str,
  59. help="An optional Git revision id which can be a branch name, a tag, or a commit hash.",
  60. )
  61. download_parser.add_argument(
  62. "--include", nargs="*", type=str, help="Glob patterns to match files to download."
  63. )
  64. download_parser.add_argument(
  65. "--exclude", nargs="*", type=str, help="Glob patterns to exclude from files to download."
  66. )
  67. download_parser.add_argument(
  68. "--cache-dir", type=str, help="Path to the directory where to save the downloaded files."
  69. )
  70. download_parser.add_argument(
  71. "--local-dir",
  72. type=str,
  73. help=(
  74. "If set, the downloaded file will be placed under this directory. Check out"
  75. " https://huggingface.co/docs/huggingface_hub/guides/download#download-files-to-local-folder for more"
  76. " details."
  77. ),
  78. )
  79. download_parser.add_argument(
  80. "--force-download",
  81. action="store_true",
  82. help="If True, the files will be downloaded even if they are already cached.",
  83. )
  84. download_parser.add_argument(
  85. "--token", type=str, help="A User Access Token generated from https://huggingface.co/settings/tokens"
  86. )
  87. download_parser.add_argument(
  88. "--quiet",
  89. action="store_true",
  90. help="If True, progress bars are disabled and only the path to the download files is printed.",
  91. )
  92. download_parser.add_argument(
  93. "--max-workers",
  94. type=int,
  95. default=8,
  96. help="Maximum number of workers to use for downloading files. Default is 8.",
  97. )
  98. download_parser.set_defaults(func=DownloadCommand)
  99. def __init__(self, args: Namespace) -> None:
  100. self.token = args.token
  101. self.repo_id: str = args.repo_id
  102. self.filenames: List[str] = args.filenames
  103. self.repo_type: str = args.repo_type
  104. self.revision: Optional[str] = args.revision
  105. self.include: Optional[List[str]] = args.include
  106. self.exclude: Optional[List[str]] = args.exclude
  107. self.cache_dir: Optional[str] = args.cache_dir
  108. self.local_dir: Optional[str] = args.local_dir
  109. self.force_download: bool = args.force_download
  110. self.quiet: bool = args.quiet
  111. self.max_workers: int = args.max_workers
  112. def run(self) -> None:
  113. if self.quiet:
  114. disable_progress_bars()
  115. with warnings.catch_warnings():
  116. warnings.simplefilter("ignore")
  117. print(self._download()) # Print path to downloaded files
  118. enable_progress_bars()
  119. else:
  120. logging.set_verbosity_info()
  121. print(self._download()) # Print path to downloaded files
  122. logging.set_verbosity_warning()
  123. def _download(self) -> str:
  124. # Warn user if patterns are ignored
  125. if len(self.filenames) > 0:
  126. if self.include is not None and len(self.include) > 0:
  127. warnings.warn("Ignoring `--include` since filenames have being explicitly set.")
  128. if self.exclude is not None and len(self.exclude) > 0:
  129. warnings.warn("Ignoring `--exclude` since filenames have being explicitly set.")
  130. # Single file to download: use `hf_hub_download`
  131. if len(self.filenames) == 1:
  132. return hf_hub_download(
  133. repo_id=self.repo_id,
  134. repo_type=self.repo_type,
  135. revision=self.revision,
  136. filename=self.filenames[0],
  137. cache_dir=self.cache_dir,
  138. force_download=self.force_download,
  139. token=self.token,
  140. local_dir=self.local_dir,
  141. library_name="huggingface-cli",
  142. )
  143. # Otherwise: use `snapshot_download` to ensure all files comes from same revision
  144. elif len(self.filenames) == 0:
  145. allow_patterns = self.include
  146. ignore_patterns = self.exclude
  147. else:
  148. allow_patterns = self.filenames
  149. ignore_patterns = None
  150. return snapshot_download(
  151. repo_id=self.repo_id,
  152. repo_type=self.repo_type,
  153. revision=self.revision,
  154. allow_patterns=allow_patterns,
  155. ignore_patterns=ignore_patterns,
  156. force_download=self.force_download,
  157. cache_dir=self.cache_dir,
  158. token=self.token,
  159. local_dir=self.local_dir,
  160. library_name="huggingface-cli",
  161. max_workers=self.max_workers,
  162. )