safetensors_conversion.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from typing import Optional
  2. import requests
  3. from huggingface_hub import Discussion, HfApi, get_repo_discussions
  4. from .utils import cached_file, http_user_agent, logging
  5. logger = logging.get_logger(__name__)
  6. def previous_pr(api: HfApi, model_id: str, pr_title: str, token: str) -> Optional["Discussion"]:
  7. main_commit = api.list_repo_commits(model_id, token=token)[0].commit_id
  8. for discussion in get_repo_discussions(repo_id=model_id, token=token):
  9. if discussion.title == pr_title and discussion.status == "open" and discussion.is_pull_request:
  10. commits = api.list_repo_commits(model_id, revision=discussion.git_reference, token=token)
  11. if main_commit == commits[1].commit_id:
  12. return discussion
  13. return None
  14. def spawn_conversion(token: str, private: bool, model_id: str):
  15. logger.info("Attempting to convert .bin model on the fly to safetensors.")
  16. safetensors_convert_space_url = "https://safetensors-convert.hf.space"
  17. sse_url = f"{safetensors_convert_space_url}/call/run"
  18. def start(_sse_connection):
  19. for line in _sse_connection.iter_lines():
  20. line = line.decode()
  21. if line.startswith("event:"):
  22. status = line[7:]
  23. logger.debug(f"Safetensors conversion status: {status}")
  24. if status == "complete":
  25. return
  26. elif status == "heartbeat":
  27. logger.debug("Heartbeat")
  28. else:
  29. logger.debug(f"Unknown status {status}")
  30. else:
  31. logger.debug(line)
  32. data = {"data": [model_id, private, token]}
  33. result = requests.post(sse_url, stream=True, json=data).json()
  34. event_id = result["event_id"]
  35. with requests.get(f"{sse_url}/{event_id}", stream=True) as sse_connection:
  36. try:
  37. logger.debug("Spawning safetensors automatic conversion.")
  38. start(sse_connection)
  39. except Exception as e:
  40. logger.warning(f"Error during conversion: {repr(e)}")
  41. def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
  42. private = api.model_info(model_id).private
  43. logger.info("Attempting to create safetensors variant")
  44. pr_title = "Adding `safetensors` variant of this model"
  45. token = kwargs.get("token")
  46. # This looks into the current repo's open PRs to see if a PR for safetensors was already open. If so, it
  47. # returns it. It checks that the PR was opened by the bot and not by another user so as to prevent
  48. # security breaches.
  49. pr = previous_pr(api, model_id, pr_title, token=token)
  50. if pr is None or (not private and pr.author != "SFconvertbot"):
  51. spawn_conversion(token, private, model_id)
  52. pr = previous_pr(api, model_id, pr_title, token=token)
  53. else:
  54. logger.info("Safetensors PR exists")
  55. sha = f"refs/pr/{pr.num}"
  56. return sha
  57. def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
  58. try:
  59. api = HfApi(token=cached_file_kwargs.get("token"), headers={"user-agent": http_user_agent()})
  60. sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
  61. if sha is None:
  62. return None, None
  63. cached_file_kwargs["revision"] = sha
  64. del cached_file_kwargs["_commit_hash"]
  65. # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
  66. # description.
  67. sharded = api.file_exists(
  68. pretrained_model_name_or_path,
  69. "model.safetensors.index.json",
  70. revision=sha,
  71. token=cached_file_kwargs.get("token"),
  72. )
  73. filename = "model.safetensors.index.json" if sharded else "model.safetensors"
  74. resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
  75. return resolved_archive_file, sha, sharded
  76. except Exception as e:
  77. if not ignore_errors_during_conversion:
  78. raise e