_commit_api.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968
  1. """
  2. Type definitions and utilities for the `create_commit` API
  3. """
  4. import base64
  5. import io
  6. import os
  7. import warnings
  8. from collections import defaultdict
  9. from contextlib import contextmanager
  10. from dataclasses import dataclass, field
  11. from itertools import groupby
  12. from pathlib import Path, PurePosixPath
  13. from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, Iterator, List, Literal, Optional, Tuple, Union
  14. from tqdm.contrib.concurrent import thread_map
  15. from . import constants
  16. from .errors import EntryNotFoundError, HfHubHTTPError, XetAuthorizationError, XetRefreshTokenError
  17. from .file_download import hf_hub_url
  18. from .lfs import UploadInfo, lfs_upload, post_lfs_batch_info
  19. from .utils import (
  20. FORBIDDEN_FOLDERS,
  21. XetTokenType,
  22. are_progress_bars_disabled,
  23. chunk_iterable,
  24. fetch_xet_connection_info_from_repo_info,
  25. get_session,
  26. hf_raise_for_status,
  27. logging,
  28. sha,
  29. tqdm_stream_file,
  30. validate_hf_hub_args,
  31. )
  32. from .utils import tqdm as hf_tqdm
  33. from .utils._runtime import is_xet_available
  34. if TYPE_CHECKING:
  35. from .hf_api import RepoFile
logger = logging.get_logger(__name__)
# How an added file is sent to the Hub: as a git LFS object or as a regular git blob.
# Set on each `CommitOperationAdd` from the Hub "preupload" response (see `_fetch_upload_modes`).
UploadMode = Literal["lfs", "regular"]
# Max is 1,000 per request on the Hub for HfApi.get_paths_info
# Otherwise we get:
# HfHubHTTPError: 413 Client Error: Payload Too Large for url: https://huggingface.co/api/datasets/xxx (Request ID: xxx)\n\ntoo many parameters
# See https://github.com/huggingface/huggingface_hub/issues/1503
FETCH_LFS_BATCH_SIZE = 500
# Maximum number of files sent in a single request to the LFS batch endpoint (chunk size used by `_upload_files`).
UPLOAD_BATCH_MAX_NUM_FILES = 256
  44. @dataclass
  45. class CommitOperationDelete:
  46. """
  47. Data structure holding necessary info to delete a file or a folder from a repository
  48. on the Hub.
  49. Args:
  50. path_in_repo (`str`):
  51. Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"`
  52. for a file or `"checkpoints/1fec34a/"` for a folder.
  53. is_folder (`bool` or `Literal["auto"]`, *optional*)
  54. Whether the Delete Operation applies to a folder or not. If "auto", the path
  55. type (file or folder) is guessed automatically by looking if path ends with
  56. a "/" (folder) or not (file). To explicitly set the path type, you can set
  57. `is_folder=True` or `is_folder=False`.
  58. """
  59. path_in_repo: str
  60. is_folder: Union[bool, Literal["auto"]] = "auto"
  61. def __post_init__(self):
  62. self.path_in_repo = _validate_path_in_repo(self.path_in_repo)
  63. if self.is_folder == "auto":
  64. self.is_folder = self.path_in_repo.endswith("/")
  65. if not isinstance(self.is_folder, bool):
  66. raise ValueError(
  67. f"Wrong value for `is_folder`. Must be one of [`True`, `False`, `'auto'`]. Got '{self.is_folder}'."
  68. )
  69. @dataclass
  70. class CommitOperationCopy:
  71. """
  72. Data structure holding necessary info to copy a file in a repository on the Hub.
  73. Limitations:
  74. - Only LFS files can be copied. To copy a regular file, you need to download it locally and re-upload it
  75. - Cross-repository copies are not supported.
  76. Note: you can combine a [`CommitOperationCopy`] and a [`CommitOperationDelete`] to rename an LFS file on the Hub.
  77. Args:
  78. src_path_in_repo (`str`):
  79. Relative filepath in the repo of the file to be copied, e.g. `"checkpoints/1fec34a/weights.bin"`.
  80. path_in_repo (`str`):
  81. Relative filepath in the repo where to copy the file, e.g. `"checkpoints/1fec34a/weights_copy.bin"`.
  82. src_revision (`str`, *optional*):
  83. The git revision of the file to be copied. Can be any valid git revision.
  84. Default to the target commit revision.
  85. """
  86. src_path_in_repo: str
  87. path_in_repo: str
  88. src_revision: Optional[str] = None
  89. # set to the OID of the file to be copied if it has already been uploaded
  90. # useful to determine if a commit will be empty or not.
  91. _src_oid: Optional[str] = None
  92. # set to the OID of the file to copy to if it has already been uploaded
  93. # useful to determine if a commit will be empty or not.
  94. _dest_oid: Optional[str] = None
  95. def __post_init__(self):
  96. self.src_path_in_repo = _validate_path_in_repo(self.src_path_in_repo)
  97. self.path_in_repo = _validate_path_in_repo(self.path_in_repo)
  98. @dataclass
  99. class CommitOperationAdd:
  100. """
  101. Data structure holding necessary info to upload a file to a repository on the Hub.
  102. Args:
  103. path_in_repo (`str`):
  104. Relative filepath in the repo, for example: `"checkpoints/1fec34a/weights.bin"`
  105. path_or_fileobj (`str`, `Path`, `bytes`, or `BinaryIO`):
  106. Either:
  107. - a path to a local file (as `str` or `pathlib.Path`) to upload
  108. - a buffer of bytes (`bytes`) holding the content of the file to upload
  109. - a "file object" (subclass of `io.BufferedIOBase`), typically obtained
  110. with `open(path, "rb")`. It must support `seek()` and `tell()` methods.
  111. Raises:
  112. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  113. If `path_or_fileobj` is not one of `str`, `Path`, `bytes` or `io.BufferedIOBase`.
  114. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  115. If `path_or_fileobj` is a `str` or `Path` but not a path to an existing file.
  116. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  117. If `path_or_fileobj` is a `io.BufferedIOBase` but it doesn't support both
  118. `seek()` and `tell()`.
  119. """
  120. path_in_repo: str
  121. path_or_fileobj: Union[str, Path, bytes, BinaryIO]
  122. upload_info: UploadInfo = field(init=False, repr=False)
  123. # Internal attributes
  124. # set to "lfs" or "regular" once known
  125. _upload_mode: Optional[UploadMode] = field(init=False, repr=False, default=None)
  126. # set to True if .gitignore rules prevent the file from being uploaded as LFS
  127. # (server-side check)
  128. _should_ignore: Optional[bool] = field(init=False, repr=False, default=None)
  129. # set to the remote OID of the file if it has already been uploaded
  130. # useful to determine if a commit will be empty or not
  131. _remote_oid: Optional[str] = field(init=False, repr=False, default=None)
  132. # set to True once the file has been uploaded as LFS
  133. _is_uploaded: bool = field(init=False, repr=False, default=False)
  134. # set to True once the file has been committed
  135. _is_committed: bool = field(init=False, repr=False, default=False)
  136. def __post_init__(self) -> None:
  137. """Validates `path_or_fileobj` and compute `upload_info`."""
  138. self.path_in_repo = _validate_path_in_repo(self.path_in_repo)
  139. # Validate `path_or_fileobj` value
  140. if isinstance(self.path_or_fileobj, Path):
  141. self.path_or_fileobj = str(self.path_or_fileobj)
  142. if isinstance(self.path_or_fileobj, str):
  143. path_or_fileobj = os.path.normpath(os.path.expanduser(self.path_or_fileobj))
  144. if not os.path.isfile(path_or_fileobj):
  145. raise ValueError(f"Provided path: '{path_or_fileobj}' is not a file on the local file system")
  146. elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
  147. # ^^ Inspired from: https://stackoverflow.com/questions/44584829/how-to-determine-if-file-is-opened-in-binary-or-text-mode
  148. raise ValueError(
  149. "path_or_fileobj must be either an instance of str, bytes or"
  150. " io.BufferedIOBase. If you passed a file-like object, make sure it is"
  151. " in binary mode."
  152. )
  153. if isinstance(self.path_or_fileobj, io.BufferedIOBase):
  154. try:
  155. self.path_or_fileobj.tell()
  156. self.path_or_fileobj.seek(0, os.SEEK_CUR)
  157. except (OSError, AttributeError) as exc:
  158. raise ValueError(
  159. "path_or_fileobj is a file-like object but does not implement seek() and tell()"
  160. ) from exc
  161. # Compute "upload_info" attribute
  162. if isinstance(self.path_or_fileobj, str):
  163. self.upload_info = UploadInfo.from_path(self.path_or_fileobj)
  164. elif isinstance(self.path_or_fileobj, bytes):
  165. self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj)
  166. else:
  167. self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj)
  168. @contextmanager
  169. def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]:
  170. """
  171. A context manager that yields a file-like object allowing to read the underlying
  172. data behind `path_or_fileobj`.
  173. Args:
  174. with_tqdm (`bool`, *optional*, defaults to `False`):
  175. If True, iterating over the file object will display a progress bar. Only
  176. works if the file-like object is a path to a file. Pure bytes and buffers
  177. are not supported.
  178. Example:
  179. ```python
  180. >>> operation = CommitOperationAdd(
  181. ... path_in_repo="remote/dir/weights.h5",
  182. ... path_or_fileobj="./local/weights.h5",
  183. ... )
  184. CommitOperationAdd(path_in_repo='remote/dir/weights.h5', path_or_fileobj='./local/weights.h5')
  185. >>> with operation.as_file() as file:
  186. ... content = file.read()
  187. >>> with operation.as_file(with_tqdm=True) as file:
  188. ... while True:
  189. ... data = file.read(1024)
  190. ... if not data:
  191. ... break
  192. config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
  193. >>> with operation.as_file(with_tqdm=True) as file:
  194. ... requests.put(..., data=file)
  195. config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s]
  196. ```
  197. """
  198. if isinstance(self.path_or_fileobj, str) or isinstance(self.path_or_fileobj, Path):
  199. if with_tqdm:
  200. with tqdm_stream_file(self.path_or_fileobj) as file:
  201. yield file
  202. else:
  203. with open(self.path_or_fileobj, "rb") as file:
  204. yield file
  205. elif isinstance(self.path_or_fileobj, bytes):
  206. yield io.BytesIO(self.path_or_fileobj)
  207. elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
  208. prev_pos = self.path_or_fileobj.tell()
  209. yield self.path_or_fileobj
  210. self.path_or_fileobj.seek(prev_pos, io.SEEK_SET)
  211. def b64content(self) -> bytes:
  212. """
  213. The base64-encoded content of `path_or_fileobj`
  214. Returns: `bytes`
  215. """
  216. with self.as_file() as file:
  217. return base64.b64encode(file.read())
  218. @property
  219. def _local_oid(self) -> Optional[str]:
  220. """Return the OID of the local file.
  221. This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one.
  222. If the file did not change, we won't upload it again to prevent empty commits.
  223. For LFS files, the OID corresponds to the SHA256 of the file content (used a LFS ref).
  224. For regular files, the OID corresponds to the SHA1 of the file content.
  225. Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1 of the
  226. pointer file content (not the actual file content). However, using the SHA256 is enough to detect changes
  227. and more convenient client-side.
  228. """
  229. if self._upload_mode is None:
  230. return None
  231. elif self._upload_mode == "lfs":
  232. return self.upload_info.sha256.hex()
  233. else:
  234. # Regular file => compute sha1
  235. # => no need to read by chunk since the file is guaranteed to be <=5MB.
  236. with self.as_file() as file:
  237. return sha.git_hash(file.read())
  238. def _validate_path_in_repo(path_in_repo: str) -> str:
  239. # Validate `path_in_repo` value to prevent a server-side issue
  240. if path_in_repo.startswith("/"):
  241. path_in_repo = path_in_repo[1:]
  242. if path_in_repo == "." or path_in_repo == ".." or path_in_repo.startswith("../"):
  243. raise ValueError(f"Invalid `path_in_repo` in CommitOperation: '{path_in_repo}'")
  244. if path_in_repo.startswith("./"):
  245. path_in_repo = path_in_repo[2:]
  246. for forbidden in FORBIDDEN_FOLDERS:
  247. if any(part == forbidden for part in path_in_repo.split("/")):
  248. raise ValueError(
  249. f"Invalid `path_in_repo` in CommitOperation: cannot update files under a '{forbidden}/' folder (path:"
  250. f" '{path_in_repo}')."
  251. )
  252. return path_in_repo
# Any of the operations that can be bundled into a single commit.
CommitOperation = Union[CommitOperationAdd, CommitOperationCopy, CommitOperationDelete]
  254. def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None:
  255. """
  256. Warn user when a list of operations is expected to overwrite itself in a single
  257. commit.
  258. Rules:
  259. - If a filepath is updated by multiple `CommitOperationAdd` operations, a warning
  260. message is triggered.
  261. - If a filepath is updated at least once by a `CommitOperationAdd` and then deleted
  262. by a `CommitOperationDelete`, a warning is triggered.
  263. - If a `CommitOperationDelete` deletes a filepath that is then updated by a
  264. `CommitOperationAdd`, no warning is triggered. This is usually useless (no need to
  265. delete before upload) but can happen if a user deletes an entire folder and then
  266. add new files to it.
  267. """
  268. nb_additions_per_path: Dict[str, int] = defaultdict(int)
  269. for operation in operations:
  270. path_in_repo = operation.path_in_repo
  271. if isinstance(operation, CommitOperationAdd):
  272. if nb_additions_per_path[path_in_repo] > 0:
  273. warnings.warn(
  274. "About to update multiple times the same file in the same commit:"
  275. f" '{path_in_repo}'. This can cause undesired inconsistencies in"
  276. " your repo."
  277. )
  278. nb_additions_per_path[path_in_repo] += 1
  279. for parent in PurePosixPath(path_in_repo).parents:
  280. # Also keep track of number of updated files per folder
  281. # => warns if deleting a folder overwrite some contained files
  282. nb_additions_per_path[str(parent)] += 1
  283. if isinstance(operation, CommitOperationDelete):
  284. if nb_additions_per_path[str(PurePosixPath(path_in_repo))] > 0:
  285. if operation.is_folder:
  286. warnings.warn(
  287. "About to delete a folder containing files that have just been"
  288. f" updated within the same commit: '{path_in_repo}'. This can"
  289. " cause undesired inconsistencies in your repo."
  290. )
  291. else:
  292. warnings.warn(
  293. "About to delete a file that have just been updated within the"
  294. f" same commit: '{path_in_repo}'. This can cause undesired"
  295. " inconsistencies in your repo."
  296. )
  297. @validate_hf_hub_args
  298. def _upload_files(
  299. *,
  300. additions: List[CommitOperationAdd],
  301. repo_type: str,
  302. repo_id: str,
  303. headers: Dict[str, str],
  304. endpoint: Optional[str] = None,
  305. num_threads: int = 5,
  306. revision: Optional[str] = None,
  307. create_pr: Optional[bool] = None,
  308. ):
  309. """
  310. Negotiates per-file transfer (LFS vs Xet) and uploads in batches.
  311. """
  312. xet_additions: List[CommitOperationAdd] = []
  313. lfs_actions: List[Dict] = []
  314. lfs_oid2addop: Dict[str, CommitOperationAdd] = {}
  315. for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES):
  316. chunk_list = [op for op in chunk]
  317. transfers: List[str] = ["basic", "multipart"]
  318. has_buffered_io_data = any(isinstance(op.path_or_fileobj, io.BufferedIOBase) for op in chunk_list)
  319. if is_xet_available():
  320. if not has_buffered_io_data:
  321. transfers.append("xet")
  322. else:
  323. logger.warning(
  324. "Uploading files as a binary IO buffer is not supported by Xet Storage. "
  325. "Falling back to HTTP upload."
  326. )
  327. actions_chunk, errors_chunk, chosen_transfer = post_lfs_batch_info(
  328. upload_infos=[op.upload_info for op in chunk_list],
  329. repo_id=repo_id,
  330. repo_type=repo_type,
  331. revision=revision,
  332. endpoint=endpoint,
  333. headers=headers,
  334. token=None, # already passed in 'headers'
  335. transfers=transfers,
  336. )
  337. if errors_chunk:
  338. message = "\n".join(
  339. [
  340. f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}"
  341. for err in errors_chunk
  342. ]
  343. )
  344. raise ValueError(f"LFS batch API returned errors:\n{message}")
  345. # If server returns a transfer we didn't offer (e.g "xet" while uploading from BytesIO),
  346. # fall back to LFS for this chunk.
  347. if chosen_transfer == "xet" and ("xet" in transfers):
  348. xet_additions.extend(chunk_list)
  349. else:
  350. lfs_actions.extend(actions_chunk)
  351. for op in chunk_list:
  352. lfs_oid2addop[op.upload_info.sha256.hex()] = op
  353. if len(lfs_actions) > 0:
  354. _upload_lfs_files(
  355. actions=lfs_actions,
  356. oid2addop=lfs_oid2addop,
  357. headers=headers,
  358. endpoint=endpoint,
  359. num_threads=num_threads,
  360. )
  361. if len(xet_additions) > 0:
  362. _upload_xet_files(
  363. additions=xet_additions,
  364. repo_type=repo_type,
  365. repo_id=repo_id,
  366. headers=headers,
  367. endpoint=endpoint,
  368. revision=revision,
  369. create_pr=create_pr,
  370. )
  371. @validate_hf_hub_args
  372. def _upload_lfs_files(
  373. *,
  374. actions: List[Dict],
  375. oid2addop: Dict[str, CommitOperationAdd],
  376. headers: Dict[str, str],
  377. endpoint: Optional[str] = None,
  378. num_threads: int = 5,
  379. ):
  380. """
  381. Uploads the content of `additions` to the Hub using the large file storage protocol.
  382. Relevant external documentation:
  383. - LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md
  384. Args:
  385. actions (`List[Dict]`):
  386. LFS batch actions returned by the server.
  387. oid2addop (`Dict[str, CommitOperationAdd]`):
  388. A dictionary mapping the OID of the file to the corresponding `CommitOperationAdd` object.
  389. headers (`Dict[str, str]`):
  390. Headers to use for the request, including authorization headers and user agent.
  391. endpoint (`str`, *optional*):
  392. The endpoint to use for the request. Defaults to `constants.ENDPOINT`.
  393. num_threads (`int`, *optional*):
  394. The number of concurrent threads to use when uploading. Defaults to 5.
  395. Raises:
  396. [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
  397. If an upload failed for any reason
  398. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  399. Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
  400. repo_id (`str`):
  401. A namespace (user or an organization) and a repo name separated
  402. by a `/`.
  403. headers (`Dict[str, str]`):
  404. Headers to use for the request, including authorization headers and user agent.
  405. num_threads (`int`, *optional*):
  406. The number of concurrent threads to use when uploading. Defaults to 5.
  407. revision (`str`, *optional*):
  408. The git revision to upload to.
  409. Raises:
  410. [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
  411. If an upload failed for any reason
  412. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  413. If the server returns malformed responses
  414. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
  415. If the LFS batch endpoint returned an HTTP error.
  416. """
  417. # Filter out files already present upstream
  418. filtered_actions = []
  419. for action in actions:
  420. if action.get("actions") is None:
  421. logger.debug(
  422. f"Content of file {oid2addop[action['oid']].path_in_repo} is already present upstream - skipping upload."
  423. )
  424. else:
  425. filtered_actions.append(action)
  426. # Upload according to server-provided actions
  427. def _wrapped_lfs_upload(batch_action) -> None:
  428. try:
  429. operation = oid2addop[batch_action["oid"]]
  430. lfs_upload(operation=operation, lfs_batch_action=batch_action, headers=headers, endpoint=endpoint)
  431. except Exception as exc:
  432. raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc
  433. if constants.HF_HUB_ENABLE_HF_TRANSFER:
  434. logger.debug(f"Uploading {len(filtered_actions)} LFS files to the Hub using `hf_transfer`.")
  435. for action in hf_tqdm(filtered_actions, name="huggingface_hub.lfs_upload"):
  436. _wrapped_lfs_upload(action)
  437. elif len(filtered_actions) == 1:
  438. logger.debug("Uploading 1 LFS file to the Hub")
  439. _wrapped_lfs_upload(filtered_actions[0])
  440. else:
  441. logger.debug(
  442. f"Uploading {len(filtered_actions)} LFS files to the Hub using up to {num_threads} threads concurrently"
  443. )
  444. thread_map(
  445. _wrapped_lfs_upload,
  446. filtered_actions,
  447. desc=f"Upload {len(filtered_actions)} LFS files",
  448. max_workers=num_threads,
  449. tqdm_class=hf_tqdm,
  450. )
  451. @validate_hf_hub_args
  452. def _upload_xet_files(
  453. *,
  454. additions: List[CommitOperationAdd],
  455. repo_type: str,
  456. repo_id: str,
  457. headers: Dict[str, str],
  458. endpoint: Optional[str] = None,
  459. revision: Optional[str] = None,
  460. create_pr: Optional[bool] = None,
  461. ):
  462. """
  463. Uploads the content of `additions` to the Hub using the xet storage protocol.
  464. This chunks the files and deduplicates the chunks before uploading them to xetcas storage.
  465. Args:
  466. additions (`List` of `CommitOperationAdd`):
  467. The files to be uploaded.
  468. repo_type (`str`):
  469. Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
  470. repo_id (`str`):
  471. A namespace (user or an organization) and a repo name separated
  472. by a `/`.
  473. headers (`Dict[str, str]`):
  474. Headers to use for the request, including authorization headers and user agent.
  475. endpoint: (`str`, *optional*):
  476. The endpoint to use for the xetcas service. Defaults to `constants.ENDPOINT`.
  477. revision (`str`, *optional*):
  478. The git revision to upload to.
  479. create_pr (`bool`, *optional*):
  480. Whether or not to create a Pull Request with that commit.
  481. Raises:
  482. [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
  483. If an upload failed for any reason.
  484. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  485. If the server returns malformed responses or if the user is unauthorized to upload to xet storage.
  486. [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
  487. If the LFS batch endpoint returned an HTTP error.
  488. **How it works:**
  489. The file download system uses Xet storage, which is a content-addressable storage system that breaks files into chunks
  490. for efficient storage and transfer.
  491. `hf_xet.upload_files` manages uploading files by:
  492. - Taking a list of file paths to upload
  493. - Breaking files into smaller chunks for efficient storage
  494. - Avoiding duplicate storage by recognizing identical chunks across files
  495. - Connecting to a storage server (CAS server) that manages these chunks
  496. The upload process works like this:
  497. 1. Create a local folder at ~/.cache/huggingface/xet/chunk-cache to store file chunks for reuse.
  498. 2. Process files in parallel (up to 8 files at once):
  499. 2.1. Read the file content.
  500. 2.2. Split the file content into smaller chunks based on content patterns: each chunk gets a unique ID based on what's in it.
  501. 2.3. For each chunk:
  502. - Check if it already exists in storage.
  503. - Skip uploading chunks that already exist.
  504. 2.4. Group chunks into larger blocks for efficient transfer.
  505. 2.5. Upload these blocks to the storage server.
  506. 2.6. Create and upload information about how the file is structured.
  507. 3. Return reference files that contain information about the uploaded files, which can be used later to download them.
  508. """
  509. if len(additions) == 0:
  510. return
  511. # at this point, we know that hf_xet is installed
  512. from hf_xet import upload_bytes, upload_files
  513. from .utils._xet_progress_reporting import XetProgressReporter
  514. try:
  515. xet_connection_info = fetch_xet_connection_info_from_repo_info(
  516. token_type=XetTokenType.WRITE,
  517. repo_id=repo_id,
  518. repo_type=repo_type,
  519. revision=revision,
  520. headers=headers,
  521. endpoint=endpoint,
  522. params={"create_pr": "1"} if create_pr else None,
  523. )
  524. except HfHubHTTPError as e:
  525. if e.response.status_code == 401:
  526. raise XetAuthorizationError(
  527. f"You are unauthorized to upload to xet storage for {repo_type}/{repo_id}. "
  528. f"Please check that you have configured your access token with write access to the repo."
  529. ) from e
  530. raise
  531. xet_endpoint = xet_connection_info.endpoint
  532. access_token_info = (xet_connection_info.access_token, xet_connection_info.expiration_unix_epoch)
  533. def token_refresher() -> Tuple[str, int]:
  534. new_xet_connection = fetch_xet_connection_info_from_repo_info(
  535. token_type=XetTokenType.WRITE,
  536. repo_id=repo_id,
  537. repo_type=repo_type,
  538. revision=revision,
  539. headers=headers,
  540. endpoint=endpoint,
  541. params={"create_pr": "1"} if create_pr else None,
  542. )
  543. if new_xet_connection is None:
  544. raise XetRefreshTokenError("Failed to refresh xet token")
  545. return new_xet_connection.access_token, new_xet_connection.expiration_unix_epoch
  546. if not are_progress_bars_disabled():
  547. progress = XetProgressReporter()
  548. progress_callback = progress.update_progress
  549. else:
  550. progress, progress_callback = None, None
  551. try:
  552. all_bytes_ops = [op for op in additions if isinstance(op.path_or_fileobj, bytes)]
  553. all_paths_ops = [op for op in additions if isinstance(op.path_or_fileobj, (str, Path))]
  554. if len(all_paths_ops) > 0:
  555. all_paths = [str(op.path_or_fileobj) for op in all_paths_ops]
  556. upload_files(
  557. all_paths,
  558. xet_endpoint,
  559. access_token_info,
  560. token_refresher,
  561. progress_callback,
  562. repo_type,
  563. )
  564. if len(all_bytes_ops) > 0:
  565. all_bytes = [op.path_or_fileobj for op in all_bytes_ops]
  566. upload_bytes(
  567. all_bytes,
  568. xet_endpoint,
  569. access_token_info,
  570. token_refresher,
  571. progress_callback,
  572. repo_type,
  573. )
  574. finally:
  575. if progress is not None:
  576. progress.close(False)
  577. return
  578. def _validate_preupload_info(preupload_info: dict):
  579. files = preupload_info.get("files")
  580. if not isinstance(files, list):
  581. raise ValueError("preupload_info is improperly formatted")
  582. for file_info in files:
  583. if not (
  584. isinstance(file_info, dict)
  585. and isinstance(file_info.get("path"), str)
  586. and isinstance(file_info.get("uploadMode"), str)
  587. and (file_info["uploadMode"] in ("lfs", "regular"))
  588. ):
  589. raise ValueError("preupload_info is improperly formatted:")
  590. return preupload_info
  @validate_hf_hub_args
  def _fetch_upload_modes(
      additions: Iterable[CommitOperationAdd],
      repo_type: str,
      repo_id: str,
      headers: Dict[str, str],
      revision: str,
      endpoint: Optional[str] = None,
      create_pr: bool = False,
      gitignore_content: Optional[str] = None,
  ) -> None:
      """
      Requests the Hub "preupload" endpoint to determine whether each input file should be uploaded as a regular git blob,
      as a git LFS blob, or as a XET file. Input `additions` are mutated in-place with the upload mode.

      Args:
          additions (`Iterable` of :class:`CommitOperationAdd`):
              Iterable of :class:`CommitOperationAdd` describing the files to
              upload to the Hub. Any iterable is accepted (including a one-shot
              generator): it is materialized into a list before being traversed.
          repo_type (`str`):
              Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
          repo_id (`str`):
              A namespace (user or an organization) and a repo name separated
              by a `/`.
          headers (`Dict[str, str]`):
              Headers to use for the request, including authorization headers and user agent.
          revision (`str`):
              The git revision to upload the files to. Can be any valid git revision.
              It is inserted as-is in the request URL.
          endpoint (`str`, *optional*):
              Base URL of the Hub to query. Defaults to `constants.ENDPOINT`.
          create_pr (`bool`, *optional*, defaults to `False`):
              If `True`, the request is sent with the `create_pr=1` query parameter.
          gitignore_content (`str`, *optional*):
              The content of the `.gitignore` file to know which files should be ignored. The order of priority
              is to first check if `gitignore_content` is passed, then check if the `.gitignore` file is present
              in the list of files to commit and finally default to the `.gitignore` file already hosted on the Hub
              (if any).

      Raises:
          [`~utils.HfHubHTTPError`]
              If the Hub API returned an error.
          [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
              If the Hub API response is improperly formatted.
      """
      endpoint = endpoint if endpoint is not None else constants.ENDPOINT

      # `additions` is traversed twice below: once to build the requests, once to apply the results.
      # Materialize it so that a one-shot iterator (e.g. a generator) is not silently exhausted by the
      # first pass, which would leave the operations without any upload mode.
      additions = list(additions)

      upload_modes: Dict[str, UploadMode] = {}
      should_ignore_info: Dict[str, bool] = {}
      oid_info: Dict[str, Optional[str]] = {}

      # Fetch upload info chunk by chunk (at most 256 files per request).
      for chunk in chunk_iterable(additions, 256):
          payload: Dict = {
              "files": [
                  {
                      "path": op.path_in_repo,
                      "sample": base64.b64encode(op.upload_info.sample).decode("ascii"),
                      "size": op.upload_info.size,
                  }
                  for op in chunk
              ]
          }
          if gitignore_content is not None:
              payload["gitIgnore"] = gitignore_content

          resp = get_session().post(
              f"{endpoint}/api/{repo_type}s/{repo_id}/preupload/{revision}",
              json=payload,
              headers=headers,
              params={"create_pr": "1"} if create_pr else None,
          )
          hf_raise_for_status(resp)
          preupload_info = _validate_preupload_info(resp.json())
          for file in preupload_info["files"]:
              path = file["path"]
              upload_modes[path] = file["uploadMode"]
              should_ignore_info[path] = file["shouldIgnore"]
              # "oid" may be absent from the response: store None in that case.
              oid_info[path] = file.get("oid")

      # Set upload mode (and related info) for each addition operation.
      for addition in additions:
          path = addition.path_in_repo
          addition._upload_mode = upload_modes[path]
          addition._should_ignore = should_ignore_info[path]
          addition._remote_oid = oid_info[path]
          # Empty files cannot be uploaded as LFS (S3 would fail with a 501 Not Implemented)
          # => empty files are uploaded as "regular" to still allow users to commit them.
          if addition.upload_info.size == 0:
              addition._upload_mode = "regular"
  @validate_hf_hub_args
  def _fetch_files_to_copy(
      copies: Iterable[CommitOperationCopy],
      repo_type: str,
      repo_id: str,
      headers: Dict[str, str],
      revision: str,
      endpoint: Optional[str] = None,
  ) -> Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]]:
      """
      Fetch information about the files to copy.

      For LFS files, we only need their metadata (file size and sha256) while for regular files
      we need to download the raw content from the Hub. As a side effect, `_src_oid` and `_dest_oid`
      are set on each copy operation.

      Args:
          copies (`Iterable` of :class:`CommitOperationCopy`):
              Iterable of :class:`CommitOperationCopy` describing the files to
              copy on the Hub. Any iterable is accepted (including a one-shot
              generator): it is materialized into a list before being traversed.
          repo_type (`str`):
              Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
          repo_id (`str`):
              A namespace (user or an organization) and a repo name separated
              by a `/`.
          headers (`Dict[str, str]`):
              Headers to use for the request, including authorization headers and user agent.
          revision (`str`):
              The git revision to upload the files to. Can be any valid git revision.
              Also used as the source revision for operations whose `src_revision` is None.
          endpoint (`str`, *optional*):
              Base URL of the Hub, forwarded to `HfApi` and `hf_hub_url`.

      Returns: `Dict[Tuple[str, Optional[str]], Union[RepoFile, bytes]]`
          Key is the file path and revision of the file to copy.
          Value is the raw content as bytes (for regular files) or the file information as a RepoFile (for LFS files).

      Raises:
          [`~utils.HfHubHTTPError`]
              If the Hub API returned an error.
          [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
              If the Hub API response is improperly formatted.
          [`NotImplementedError`](https://docs.python.org/3/library/exceptions.html#NotImplementedError)
              If a source path points to a folder.
          [`~utils.EntryNotFoundError`]
              If a source file is missing at the requested revision.
      """
      from .hf_api import HfApi, RepoFolder

      hf_api = HfApi(endpoint=endpoint, headers=headers)
      files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]] = {}
      # Store (path, revision) -> oid mapping
      oid_info: Dict[Tuple[str, Optional[str]], Optional[str]] = {}

      # `copies` is traversed several times below: materialize it so that a one-shot iterator
      # (e.g. a generator) is not exhausted after the first pass.
      copies = list(copies)

      # 1. Fetch OIDs for destination paths in batches.
      dest_paths = [op.path_in_repo for op in copies]
      for offset in range(0, len(dest_paths), FETCH_LFS_BATCH_SIZE):
          dest_repo_files = hf_api.get_paths_info(
              repo_id=repo_id,
              paths=dest_paths[offset : offset + FETCH_LFS_BATCH_SIZE],
              revision=revision,
              repo_type=repo_type,
          )
          for file in dest_repo_files:
              if not isinstance(file, RepoFolder):
                  oid_info[(file.path, revision)] = file.blob_id

      # 2. Group operations by source revision. `itertools.groupby` only merges *consecutive* items,
      # so interleaved revisions would trigger duplicate requests and downloads. A dict groups them
      # all at once while preserving first-seen order (sorting is not an option: None and str mix).
      copies_by_src_revision: Dict[Optional[str], list] = {}
      for op in copies:
          copies_by_src_revision.setdefault(op.src_revision, []).append(op)

      for src_revision, operations in copies_by_src_revision.items():
          # Several copies may share the same source file: only fetch/download it once.
          src_paths = list(dict.fromkeys(op.src_path_in_repo for op in operations))
          for offset in range(0, len(src_paths), FETCH_LFS_BATCH_SIZE):
              src_repo_files = hf_api.get_paths_info(
                  repo_id=repo_id,
                  paths=src_paths[offset : offset + FETCH_LFS_BATCH_SIZE],
                  revision=src_revision or revision,
                  repo_type=repo_type,
              )
              for src_repo_file in src_repo_files:
                  if isinstance(src_repo_file, RepoFolder):
                      raise NotImplementedError("Copying a folder is not implemented.")
                  oid_info[(src_repo_file.path, src_revision)] = src_repo_file.blob_id
                  # If it's an LFS file, store the RepoFile object. Otherwise, download raw bytes.
                  if src_repo_file.lfs:
                      files_to_copy[(src_repo_file.path, src_revision)] = src_repo_file
                  else:
                      # TODO: (optimization) download regular files to copy concurrently
                      url = hf_hub_url(
                          endpoint=endpoint,
                          repo_type=repo_type,
                          repo_id=repo_id,
                          revision=src_revision or revision,
                          filename=src_repo_file.path,
                      )
                      response = get_session().get(url, headers=headers)
                      hf_raise_for_status(response)
                      files_to_copy[(src_repo_file.path, src_revision)] = response.content

          # 3. Ensure all operations found a corresponding file in the Hub
          # and track src/dest OIDs for each operation.
          for operation in operations:
              if (operation.src_path_in_repo, src_revision) not in files_to_copy:
                  raise EntryNotFoundError(
                      f"Cannot copy {operation.src_path_in_repo} at revision "
                      f"{src_revision or revision}: file is missing on repo."
                  )
              operation._src_oid = oid_info.get((operation.src_path_in_repo, operation.src_revision))
              operation._dest_oid = oid_info.get((operation.path_in_repo, revision))

      return files_to_copy
  761. def _prepare_commit_payload(
  762. operations: Iterable[CommitOperation],
  763. files_to_copy: Dict[Tuple[str, Optional[str]], Union["RepoFile", bytes]],
  764. commit_message: str,
  765. commit_description: Optional[str] = None,
  766. parent_commit: Optional[str] = None,
  767. ) -> Iterable[Dict[str, Any]]:
  768. """
  769. Builds the payload to POST to the `/commit` API of the Hub.
  770. Payload is returned as an iterator so that it can be streamed as a ndjson in the
  771. POST request.
  772. For more information, see:
  773. - https://github.com/huggingface/huggingface_hub/issues/1085#issuecomment-1265208073
  774. - http://ndjson.org/
  775. """
  776. commit_description = commit_description if commit_description is not None else ""
  777. # 1. Send a header item with the commit metadata
  778. header_value = {"summary": commit_message, "description": commit_description}
  779. if parent_commit is not None:
  780. header_value["parentCommit"] = parent_commit
  781. yield {"key": "header", "value": header_value}
  782. nb_ignored_files = 0
  783. # 2. Send operations, one per line
  784. for operation in operations:
  785. # Skip ignored files
  786. if isinstance(operation, CommitOperationAdd) and operation._should_ignore:
  787. logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).")
  788. nb_ignored_files += 1
  789. continue
  790. # 2.a. Case adding a regular file
  791. if isinstance(operation, CommitOperationAdd) and operation._upload_mode == "regular":
  792. yield {
  793. "key": "file",
  794. "value": {
  795. "content": operation.b64content().decode(),
  796. "path": operation.path_in_repo,
  797. "encoding": "base64",
  798. },
  799. }
  800. # 2.b. Case adding an LFS file
  801. elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == "lfs":
  802. yield {
  803. "key": "lfsFile",
  804. "value": {
  805. "path": operation.path_in_repo,
  806. "algo": "sha256",
  807. "oid": operation.upload_info.sha256.hex(),
  808. "size": operation.upload_info.size,
  809. },
  810. }
  811. # 2.c. Case deleting a file or folder
  812. elif isinstance(operation, CommitOperationDelete):
  813. yield {
  814. "key": "deletedFolder" if operation.is_folder else "deletedFile",
  815. "value": {"path": operation.path_in_repo},
  816. }
  817. # 2.d. Case copying a file or folder
  818. elif isinstance(operation, CommitOperationCopy):
  819. file_to_copy = files_to_copy[(operation.src_path_in_repo, operation.src_revision)]
  820. if isinstance(file_to_copy, bytes):
  821. yield {
  822. "key": "file",
  823. "value": {
  824. "content": base64.b64encode(file_to_copy).decode(),
  825. "path": operation.path_in_repo,
  826. "encoding": "base64",
  827. },
  828. }
  829. elif file_to_copy.lfs:
  830. yield {
  831. "key": "lfsFile",
  832. "value": {
  833. "path": operation.path_in_repo,
  834. "algo": "sha256",
  835. "oid": file_to_copy.lfs.sha256,
  836. },
  837. }
  838. else:
  839. raise ValueError(
  840. "Malformed files_to_copy (should be raw file content as bytes or RepoFile objects with LFS info."
  841. )
  842. # 2.e. Never expected to happen
  843. else:
  844. raise ValueError(
  845. f"Unknown operation to commit. Operation: {operation}. Upload mode:"
  846. f" {getattr(operation, '_upload_mode', None)}"
  847. )
  848. if nb_ignored_files > 0:
  849. logger.info(f"Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).")