lfs.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. """
  2. Implementation of a custom transfer agent for the transfer type "multipart" for
  3. git-lfs.
  4. Inspired by:
  5. github.com/cbartz/git-lfs-swift-transfer-agent/blob/master/git_lfs_swift_transfer.py
  6. Spec is: github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md
To launch the debugger while developing, add the following to your git config:

```
[lfs "customtransfer.multipart"]
    path = /path/to/huggingface_hub/.env/bin/python
    args = -m debugpy --listen 5678 --wait-for-client /path/to/huggingface_hub/src/huggingface_hub/commands/huggingface_cli.py lfs-multipart-upload
```
"""
  13. import json
  14. import os
  15. import subprocess
  16. import sys
  17. from argparse import _SubParsersAction
  18. from typing import Dict, List, Optional
  19. from huggingface_hub.commands import BaseHuggingfaceCLICommand
  20. from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND
  21. from ..utils import get_session, hf_raise_for_status, logging
  22. from ..utils._lfs import SliceFileObj
# Module-level logger, named after this module and obtained through the
# package's own `logging` helper imported above.
logger = logging.get_logger(__name__)
  24. class LfsCommands(BaseHuggingfaceCLICommand):
  25. """
  26. Implementation of a custom transfer agent for the transfer type "multipart"
  27. for git-lfs. This lets users upload large files >5GB 🔥. Spec for LFS custom
  28. transfer agent is:
  29. https://github.com/git-lfs/git-lfs/blob/master/docs/custom-transfers.md
  30. This introduces two commands to the CLI:
  31. 1. $ hf lfs-enable-largefiles
  32. This should be executed once for each model repo that contains a model file
  33. >5GB. It's documented in the error message you get if you just try to git
  34. push a 5GB file without having enabled it before.
  35. 2. $ hf lfs-multipart-upload
  36. This command is called by lfs directly and is not meant to be called by the
  37. user.
  38. """
  39. @staticmethod
  40. def register_subcommand(parser: _SubParsersAction):
  41. enable_parser = parser.add_parser("lfs-enable-largefiles", add_help=False)
  42. enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.")
  43. enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args))
  44. # Command will get called by git-lfs, do not call it directly.
  45. upload_parser = parser.add_parser(LFS_MULTIPART_UPLOAD_COMMAND, add_help=False)
  46. upload_parser.set_defaults(func=lambda args: LfsUploadCommand(args))
  47. class LfsEnableCommand:
  48. def __init__(self, args):
  49. self.args = args
  50. def run(self):
  51. local_path = os.path.abspath(self.args.path)
  52. if not os.path.isdir(local_path):
  53. print("This does not look like a valid git repo.")
  54. exit(1)
  55. subprocess.run(
  56. "git config lfs.customtransfer.multipart.path hf".split(),
  57. check=True,
  58. cwd=local_path,
  59. )
  60. subprocess.run(
  61. f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(),
  62. check=True,
  63. cwd=local_path,
  64. )
  65. print("Local repo set up for largefiles")
  66. def write_msg(msg: Dict):
  67. """Write out the message in Line delimited JSON."""
  68. msg_str = json.dumps(msg) + "\n"
  69. sys.stdout.write(msg_str)
  70. sys.stdout.flush()
  71. def read_msg() -> Optional[Dict]:
  72. """Read Line delimited JSON from stdin."""
  73. msg = json.loads(sys.stdin.readline().strip())
  74. if "terminate" in (msg.get("type"), msg.get("event")):
  75. # terminate message received
  76. return None
  77. if msg.get("event") not in ("download", "upload"):
  78. logger.critical("Received unexpected message")
  79. sys.exit(1)
  80. return msg
  81. class LfsUploadCommand:
  82. def __init__(self, args) -> None:
  83. self.args = args
  84. def run(self) -> None:
  85. # Immediately after invoking a custom transfer process, git-lfs
  86. # sends initiation data to the process over stdin.
  87. # This tells the process useful information about the configuration.
  88. init_msg = json.loads(sys.stdin.readline().strip())
  89. if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"):
  90. write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}})
  91. sys.exit(1)
  92. # The transfer process should use the information it needs from the
  93. # initiation structure, and also perform any one-off setup tasks it
  94. # needs to do. It should then respond on stdout with a simple empty
  95. # confirmation structure, as follows:
  96. write_msg({})
  97. # After the initiation exchange, git-lfs will send any number of
  98. # transfer requests to the stdin of the transfer process, in a serial sequence.
  99. while True:
  100. msg = read_msg()
  101. if msg is None:
  102. # When all transfers have been processed, git-lfs will send
  103. # a terminate event to the stdin of the transfer process.
  104. # On receiving this message the transfer process should
  105. # clean up and terminate. No response is expected.
  106. sys.exit(0)
  107. oid = msg["oid"]
  108. filepath = msg["path"]
  109. completion_url = msg["action"]["href"]
  110. header = msg["action"]["header"]
  111. chunk_size = int(header.pop("chunk_size"))
  112. presigned_urls: List[str] = list(header.values())
  113. # Send a "started" progress event to allow other workers to start.
  114. # Otherwise they're delayed until first "progress" event is reported,
  115. # i.e. after the first 5GB by default (!)
  116. write_msg(
  117. {
  118. "event": "progress",
  119. "oid": oid,
  120. "bytesSoFar": 1,
  121. "bytesSinceLast": 0,
  122. }
  123. )
  124. parts = []
  125. with open(filepath, "rb") as file:
  126. for i, presigned_url in enumerate(presigned_urls):
  127. with SliceFileObj(
  128. file,
  129. seek_from=i * chunk_size,
  130. read_limit=chunk_size,
  131. ) as data:
  132. r = get_session().put(presigned_url, data=data)
  133. hf_raise_for_status(r)
  134. parts.append(
  135. {
  136. "etag": r.headers.get("etag"),
  137. "partNumber": i + 1,
  138. }
  139. )
  140. # In order to support progress reporting while data is uploading / downloading,
  141. # the transfer process should post messages to stdout
  142. write_msg(
  143. {
  144. "event": "progress",
  145. "oid": oid,
  146. "bytesSoFar": (i + 1) * chunk_size,
  147. "bytesSinceLast": chunk_size,
  148. }
  149. )
  150. # Not precise but that's ok.
  151. r = get_session().post(
  152. completion_url,
  153. json={
  154. "oid": oid,
  155. "parts": parts,
  156. },
  157. )
  158. hf_raise_for_status(r)
  159. write_msg({"event": "complete", "oid": oid})