_xet.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. from dataclasses import dataclass
  2. from enum import Enum
  3. from typing import Dict, Optional
  4. import requests
  5. from .. import constants
  6. from . import get_session, hf_raise_for_status, validate_hf_hub_args
  class XetTokenType(str, Enum):
      """
      Kind of Xet access token that can be requested from the Hub: `"read"` or `"write"`.

      Because the enum mixes in `str`, each member compares equal to its raw string value
      (for example `XetTokenType.READ == "read"`), and `.value` yields that string.
      """

      # Declaration order is kept (READ first, WRITE second) so iteration order is stable.
      READ = "read"
      WRITE = "write"
  @dataclass(frozen=True)
  class XetFileData:
      """
      Immutable Xet metadata extracted from a Hub HTTP response for a single file.

      Instances are frozen, so fields cannot be reassigned after construction and
      instances are hashable.
      """

      # Hash value read from the Xet hash response header.
      file_hash: str
      # URL used to request fresh Xet connection info (access token, expiration, endpoint).
      refresh_route: str
  @dataclass(frozen=True)
  class XetConnectionInfo:
      """
      Immutable set of values needed to talk to the Xet storage service.

      Built from the Xet-specific HTTP response headers returned by the Hub. Positional
      construction follows the field order below: token, expiration, endpoint.
      """

      # Access token sent to the Xet storage service.
      access_token: str
      # Expiration time of `access_token`, as an integer Unix epoch timestamp.
      expiration_unix_epoch: int
      # URL of the Xet storage service to connect to.
      endpoint: str
  def parse_xet_file_data_from_response(
      response: requests.Response, endpoint: Optional[str] = None
  ) -> Optional[XetFileData]:
      """
      Parse XET file metadata from an HTTP response.

      The file hash is read from the Xet hash response header. The refresh route is taken from
      the Xet auth entry of the response's parsed `links` when present, otherwise from the Xet
      refresh-route response header. If any required value is missing, `None` is returned.

      Args:
          response (`requests.Response`):
              The HTTP response object whose `headers` and `links` are inspected. Passing `None`
              returns `None`.
          endpoint (`str`, `optional`):
              Base URL that replaces the `huggingface.co` prefix of the refresh route when the route
              starts with it. Defaults to `constants.ENDPOINT`.

      Returns:
          `Optional[XetFileData]`:
              An instance of `XetFileData` containing the file hash and refresh route if the metadata
              is found. Returns `None` if the required metadata is missing.
      """
      if response is None:
          return None

      try:
          file_hash = response.headers[constants.HUGGINGFACE_HEADER_X_XET_HASH]
          # Prefer the Link-header entry; fall back to the dedicated refresh-route header.
          if constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY in response.links:
              refresh_route = response.links[constants.HUGGINGFACE_HEADER_LINK_XET_AUTH_KEY]["url"]
          else:
              refresh_route = response.headers[constants.HUGGINGFACE_HEADER_X_XET_REFRESH_ROUTE]
      except KeyError:
          return None

      endpoint = endpoint if endpoint is not None else constants.ENDPOINT
      if refresh_route.startswith(constants.HUGGINGFACE_CO_URL_HOME):
          # Swap only the leading host prefix. A plain `str.replace` would also rewrite any later
          # occurrence of the host inside the URL (e.g. in a query string).
          hub_home = constants.HUGGINGFACE_CO_URL_HOME.rstrip("/")
          refresh_route = endpoint.rstrip("/") + refresh_route[len(hub_home):]

      return XetFileData(
          file_hash=file_hash,
          refresh_route=refresh_route,
      )
  def parse_xet_connection_info_from_headers(headers: Dict[str, str]) -> Optional[XetConnectionInfo]:
      """
      Build a `XetConnectionInfo` from Xet-specific HTTP headers.

      Args:
          headers (`Dict[str, str]`):
              HTTP headers expected to carry the Xet endpoint, access token and expiration.

      Returns:
          `XetConnectionInfo` or `None`:
              The information needed to connect to the XET storage service, or `None` when a
              required header is absent or the expiration cannot be converted to an integer.
      """
      try:
          xet_endpoint = headers[constants.HUGGINGFACE_HEADER_X_XET_ENDPOINT]
          token = headers[constants.HUGGINGFACE_HEADER_X_XET_ACCESS_TOKEN]
          expires_at = int(headers[constants.HUGGINGFACE_HEADER_X_XET_EXPIRATION])
      except (KeyError, ValueError, TypeError):
          # Missing header or malformed expiration: treat as "no connection info available".
          return None
      else:
          return XetConnectionInfo(
              endpoint=xet_endpoint,
              access_token=token,
              expiration_unix_epoch=expires_at,
          )
  @validate_hf_hub_args
  def refresh_xet_connection_info(
      *,
      file_data: XetFileData,
      headers: Dict[str, str],
  ) -> XetConnectionInfo:
      """
      Request fresh Xet connection information using the refresh route stored in `file_data`.

      The result carries the access token, its expiration, and the XET service URL.

      Args:
          file_data (`XetFileData`):
              Parsed file metadata whose `refresh_route` is queried.
          headers (`Dict[str, str]`):
              Headers to use for the request, including authorization headers and user agent.

      Returns:
          `XetConnectionInfo`:
              The connection information needed to make the request to the xet storage service.

      Raises:
          [`~utils.HfHubHTTPError`]
              If the Hub API returned an error.
          [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
              If `file_data` has no refresh route, or the Hub API response is improperly formatted.
      """
      route = file_data.refresh_route
      if route is None:
          raise ValueError("The provided xet metadata does not contain a refresh endpoint.")
      return _fetch_xet_connection_info_with_url(route, headers)
  @validate_hf_hub_args
  def fetch_xet_connection_info_from_repo_info(
      *,
      token_type: XetTokenType,
      repo_id: str,
      repo_type: str,
      revision: Optional[str] = None,
      headers: Dict[str, str],
      endpoint: Optional[str] = None,
      params: Optional[Dict[str, str]] = None,
  ) -> XetConnectionInfo:
      """
      Uses the repo info to request a xet access token from Hub.

      Args:
          token_type (`XetTokenType`):
              Type of the token to request: `"read"` or `"write"`.
          repo_id (`str`):
              A namespace (user or an organization) and a repo name separated by a `/`.
          repo_type (`str`):
              Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
          revision (`str`, `optional`):
              The revision of the repo to get the token for. Defaults to `constants.DEFAULT_REVISION`
              when not provided.
          headers (`Dict[str, str]`):
              Headers to use for the request, including authorization headers and user agent.
          endpoint (`str`, `optional`):
              The endpoint to use for the request. Defaults to the Hub endpoint.
          params (`Dict[str, str]`, `optional`):
              Additional parameters to pass with the request.

      Returns:
          `XetConnectionInfo`:
              The connection information needed to make the request to the xet storage service.

      Raises:
          [`~utils.HfHubHTTPError`]
              If the Hub API returned an error.
          [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
              If the Hub API response is improperly formatted.
      """
      endpoint = endpoint if endpoint is not None else constants.ENDPOINT
      # Without this default, a missing revision would be interpolated as the literal string
      # "None", producing a `.../xet-<type>-token/None` route.
      revision = revision if revision is not None else constants.DEFAULT_REVISION
      # Strip a trailing slash so a configured endpoint like "https://host/" doesn't yield "//api".
      url = f"{endpoint.rstrip('/')}/api/{repo_type}s/{repo_id}/xet-{token_type.value}-token/{revision}"
      return _fetch_xet_connection_info_with_url(url, headers, params)
  @validate_hf_hub_args
  def _fetch_xet_connection_info_with_url(
      url: str,
      headers: Dict[str, str],
      params: Optional[Dict[str, str]] = None,
  ) -> XetConnectionInfo:
      """
      Perform a GET on `url` and turn the Xet response headers into a `XetConnectionInfo`.

      The returned object holds the access token, its expiration time, and the endpoint of the
      xet storage service.

      Args:
          url (`str`):
              The access token endpoint URL.
          headers (`Dict[str, str]`):
              Headers to use for the request, including authorization headers and user agent.
          params (`Dict[str, str]`, `optional`):
              Additional query parameters to pass with the request.

      Returns:
          `XetConnectionInfo`:
              The connection information needed to make the request to the xet storage service.

      Raises:
          [`~utils.HfHubHTTPError`]
              If the Hub API returned an error.
          [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
              If the response lacks the expected Xet headers.
      """
      response = get_session().get(headers=headers, url=url, params=params)
      hf_raise_for_status(response)
      connection_info = parse_xet_connection_info_from_headers(response.headers)  # type: ignore
      if connection_info is not None:
          return connection_info
      raise ValueError("Xet headers have not been correctly set by the server.")