| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630 |
- # coding=utf-8
- # Copyright 2022-present, the HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Contains utilities to handle HTTP requests in Huggingface Hub."""
- import io
- import os
- import re
- import threading
- import time
- import uuid
- from functools import lru_cache
- from shlex import quote
- from typing import Any, Callable, List, Optional, Tuple, Type, Union
- import requests
- from requests import HTTPError, Response
- from requests.adapters import HTTPAdapter
- from requests.models import PreparedRequest
- from huggingface_hub.errors import OfflineModeIsEnabled
- from .. import constants
- from ..errors import (
- BadRequestError,
- DisabledRepoError,
- EntryNotFoundError,
- GatedRepoError,
- HfHubHTTPError,
- RepositoryNotFoundError,
- RevisionNotFoundError,
- )
- from . import logging
- from ._fixes import JSONDecodeError
- from ._lfs import SliceFileObj
- from ._typing import HTTP_METHOD_T
- logger = logging.get_logger(__name__)
- # Both headers are used by the Hub to debug failed requests.
- # `X_AMZN_TRACE_ID` is better as it also works to debug on Cloudfront and ALB.
- # If `X_AMZN_TRACE_ID` is set, the Hub will use it as well.
- X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"
- X_REQUEST_ID = "x-request-id"
- REPO_API_REGEX = re.compile(
- r"""
- # staging or production endpoint
- ^https://[^/]+
- (
- # on /api/repo_type/repo_id
- /api/(models|datasets|spaces)/(.+)
- |
- # or /repo_id/resolve/revision/...
- /(.+)/resolve/(.+)
- )
- """,
- flags=re.VERBOSE,
- )
- class UniqueRequestIdAdapter(HTTPAdapter):
- X_AMZN_TRACE_ID = "X-Amzn-Trace-Id"
- def add_headers(self, request, **kwargs):
- super().add_headers(request, **kwargs)
- # Add random request ID => easier for server-side debug
- if X_AMZN_TRACE_ID not in request.headers:
- request.headers[X_AMZN_TRACE_ID] = request.headers.get(X_REQUEST_ID) or str(uuid.uuid4())
- # Add debug log
- has_token = len(str(request.headers.get("authorization", ""))) > 0
- logger.debug(
- f"Request {request.headers[X_AMZN_TRACE_ID]}: {request.method} {request.url} (authenticated: {has_token})"
- )
- def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
- """Catch any RequestException to append request id to the error message for debugging."""
- if constants.HF_DEBUG:
- logger.debug(f"Send: {_curlify(request)}")
- try:
- return super().send(request, *args, **kwargs)
- except requests.RequestException as e:
- request_id = request.headers.get(X_AMZN_TRACE_ID)
- if request_id is not None:
- # Taken from https://stackoverflow.com/a/58270258
- e.args = (*e.args, f"(Request ID: {request_id})")
- raise
- class OfflineAdapter(HTTPAdapter):
- def send(self, request: PreparedRequest, *args, **kwargs) -> Response:
- raise OfflineModeIsEnabled(
- f"Cannot reach {request.url}: offline mode is enabled. To disable it, please unset the `HF_HUB_OFFLINE` environment variable."
- )
- def _default_backend_factory() -> requests.Session:
- session = requests.Session()
- if constants.HF_HUB_OFFLINE:
- session.mount("http://", OfflineAdapter())
- session.mount("https://", OfflineAdapter())
- else:
- session.mount("http://", UniqueRequestIdAdapter())
- session.mount("https://", UniqueRequestIdAdapter())
- return session
- BACKEND_FACTORY_T = Callable[[], requests.Session]
- _GLOBAL_BACKEND_FACTORY: BACKEND_FACTORY_T = _default_backend_factory
- def configure_http_backend(backend_factory: BACKEND_FACTORY_T = _default_backend_factory) -> None:
- """
- Configure the HTTP backend by providing a `backend_factory`. Any HTTP calls made by `huggingface_hub` will use a
- Session object instantiated by this factory. This can be useful if you are running your scripts in a specific
- environment requiring custom configuration (e.g. custom proxy or certifications).
- Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe,
- `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory`
- set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between
- calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned.
- See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`.
- Example:
- ```py
- import requests
- from huggingface_hub import configure_http_backend, get_session
- # Create a factory function that returns a Session with configured proxies
- def backend_factory() -> requests.Session:
- session = requests.Session()
- session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"}
- return session
- # Set it as the default session factory
- configure_http_backend(backend_factory=backend_factory)
- # In practice, this is mostly done internally in `huggingface_hub`
- session = get_session()
- ```
- """
- global _GLOBAL_BACKEND_FACTORY
- _GLOBAL_BACKEND_FACTORY = backend_factory
- reset_sessions()
- def get_session() -> requests.Session:
- """
- Get a `requests.Session` object, using the session factory from the user.
- Use [`get_session`] to get a configured Session. Since `requests.Session` is not guaranteed to be thread-safe,
- `huggingface_hub` creates 1 Session instance per thread. They are all instantiated using the same `backend_factory`
- set in [`configure_http_backend`]. A LRU cache is used to cache the created sessions (and connections) between
- calls. Max size is 128 to avoid memory leaks if thousands of threads are spawned.
- See [this issue](https://github.com/psf/requests/issues/2766) to know more about thread-safety in `requests`.
- Example:
- ```py
- import requests
- from huggingface_hub import configure_http_backend, get_session
- # Create a factory function that returns a Session with configured proxies
- def backend_factory() -> requests.Session:
- session = requests.Session()
- session.proxies = {"http": "http://10.10.1.10:3128", "https": "https://10.10.1.11:1080"}
- return session
- # Set it as the default session factory
- configure_http_backend(backend_factory=backend_factory)
- # In practice, this is mostly done internally in `huggingface_hub`
- session = get_session()
- ```
- """
- return _get_session_from_cache(process_id=os.getpid(), thread_id=threading.get_ident())
- def reset_sessions() -> None:
- """Reset the cache of sessions.
- Mostly used internally when sessions are reconfigured or an SSLError is raised.
- See [`configure_http_backend`] for more details.
- """
- _get_session_from_cache.cache_clear()
- @lru_cache
- def _get_session_from_cache(process_id: int, thread_id: int) -> requests.Session:
- """
- Create a new session per thread using global factory. Using LRU cache (maxsize 128) to avoid memory leaks when
- using thousands of threads. Cache is cleared when `configure_http_backend` is called.
- """
- return _GLOBAL_BACKEND_FACTORY()
- def http_backoff(
- method: HTTP_METHOD_T,
- url: str,
- *,
- max_retries: int = 5,
- base_wait_time: float = 1,
- max_wait_time: float = 8,
- retry_on_exceptions: Union[Type[Exception], Tuple[Type[Exception], ...]] = (
- requests.Timeout,
- requests.ConnectionError,
- requests.exceptions.ChunkedEncodingError,
- ),
- retry_on_status_codes: Union[int, Tuple[int, ...]] = (500, 502, 503, 504),
- **kwargs,
- ) -> Response:
- """Wrapper around requests to retry calls on an endpoint, with exponential backoff.
- Endpoint call is retried on exceptions (ex: connection timeout, proxy error,...)
- and/or on specific status codes (ex: service unavailable). If the call failed more
- than `max_retries`, the exception is thrown or `raise_for_status` is called on the
- response object.
- Re-implement mechanisms from the `backoff` library to avoid adding an external
- dependencies to `hugging_face_hub`. See https://github.com/litl/backoff.
- Args:
- method (`Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]`):
- HTTP method to perform.
- url (`str`):
- The URL of the resource to fetch.
- max_retries (`int`, *optional*, defaults to `5`):
- Maximum number of retries, defaults to 5 (no retries).
- base_wait_time (`float`, *optional*, defaults to `1`):
- Duration (in seconds) to wait before retrying the first time.
- Wait time between retries then grows exponentially, capped by
- `max_wait_time`.
- max_wait_time (`float`, *optional*, defaults to `8`):
- Maximum duration (in seconds) to wait before retrying.
- retry_on_exceptions (`Type[Exception]` or `Tuple[Type[Exception]]`, *optional*):
- Define which exceptions must be caught to retry the request. Can be a single type or a tuple of types.
- By default, retry on `requests.Timeout`, `requests.ConnectionError` and `requests.exceptions.ChunkedEncodingError`.
- retry_on_status_codes (`int` or `Tuple[int]`, *optional*, defaults to `(500, 502, 503, 504)`):
- Define on which status codes the request must be retried. By default, 5xx errors are retried.
- **kwargs (`dict`, *optional*):
- kwargs to pass to `requests.request`.
- Example:
- ```
- >>> from huggingface_hub.utils import http_backoff
- # Same usage as "requests.request".
- >>> response = http_backoff("GET", "https://www.google.com")
- >>> response.raise_for_status()
- # If you expect a Gateway Timeout from time to time
- >>> http_backoff("PUT", upload_url, data=data, retry_on_status_codes=504)
- >>> response.raise_for_status()
- ```
- > [!WARNING]
- > When using `requests` it is possible to stream data by passing an iterator to the
- > `data` argument. On http backoff this is a problem as the iterator is not reset
- > after a failed call. This issue is mitigated for file objects or any IO streams
- > by saving the initial position of the cursor (with `data.tell()`) and resetting the
- > cursor between each call (with `data.seek()`). For arbitrary iterators, http backoff
- > will fail. If this is a hard constraint for you, please let us know by opening an
- > issue on [Github](https://github.com/huggingface/huggingface_hub).
- """
- if isinstance(retry_on_exceptions, type): # Tuple from single exception type
- retry_on_exceptions = (retry_on_exceptions,)
- if isinstance(retry_on_status_codes, int): # Tuple from single status code
- retry_on_status_codes = (retry_on_status_codes,)
- nb_tries = 0
- sleep_time = base_wait_time
- # If `data` is used and is a file object (or any IO), it will be consumed on the
- # first HTTP request. We need to save the initial position so that the full content
- # of the file is re-sent on http backoff. See warning tip in docstring.
- io_obj_initial_pos = None
- if "data" in kwargs and isinstance(kwargs["data"], (io.IOBase, SliceFileObj)):
- io_obj_initial_pos = kwargs["data"].tell()
- session = get_session()
- while True:
- nb_tries += 1
- try:
- # If `data` is used and is a file object (or any IO), set back cursor to
- # initial position.
- if io_obj_initial_pos is not None:
- kwargs["data"].seek(io_obj_initial_pos)
- # Perform request and return if status_code is not in the retry list.
- response = session.request(method=method, url=url, **kwargs)
- if response.status_code not in retry_on_status_codes:
- return response
- # Wrong status code returned (HTTP 503 for instance)
- logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}")
- if nb_tries > max_retries:
- response.raise_for_status() # Will raise uncaught exception
- # We return response to avoid infinite loop in the corner case where the
- # user ask for retry on a status code that doesn't raise_for_status.
- return response
- except retry_on_exceptions as err:
- logger.warning(f"'{err}' thrown while requesting {method} {url}")
- if isinstance(err, requests.ConnectionError):
- reset_sessions() # In case of SSLError it's best to reset the shared requests.Session objects
- if nb_tries > max_retries:
- raise err
- # Sleep for X seconds
- logger.warning(f"Retrying in {sleep_time}s [Retry {nb_tries}/{max_retries}].")
- time.sleep(sleep_time)
- # Update sleep time for next retry
- sleep_time = min(max_wait_time, sleep_time * 2) # Exponential backoff
- def fix_hf_endpoint_in_url(url: str, endpoint: Optional[str]) -> str:
- """Replace the default endpoint in a URL by a custom one.
- This is useful when using a proxy and the Hugging Face Hub returns a URL with the default endpoint.
- """
- endpoint = endpoint.rstrip("/") if endpoint else constants.ENDPOINT
- # check if a proxy has been set => if yes, update the returned URL to use the proxy
- if endpoint not in (constants._HF_DEFAULT_ENDPOINT, constants._HF_DEFAULT_STAGING_ENDPOINT):
- url = url.replace(constants._HF_DEFAULT_ENDPOINT, endpoint)
- url = url.replace(constants._HF_DEFAULT_STAGING_ENDPOINT, endpoint)
- return url
- def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None:
- """
- Internal version of `response.raise_for_status()` that will refine a
- potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`.
- This helper is meant to be the unique method to raise_for_status when making a call
- to the Hugging Face Hub.
- Example:
- ```py
- import requests
- from huggingface_hub.utils import get_session, hf_raise_for_status, HfHubHTTPError
- response = get_session().post(...)
- try:
- hf_raise_for_status(response)
- except HfHubHTTPError as e:
- print(str(e)) # formatted message
- e.request_id, e.server_message # details returned by server
- # Complete the error message with additional information once it's raised
- e.append_to_message("\n`create_commit` expects the repository to exist.")
- raise
- ```
- Args:
- response (`Response`):
- Response from the server.
- endpoint_name (`str`, *optional*):
- Name of the endpoint that has been called. If provided, the error message
- will be more complete.
- > [!WARNING]
- > Raises when the request has failed:
- >
- > - [`~utils.RepositoryNotFoundError`]
- > If the repository to download from cannot be found. This may be because it
- > doesn't exist, because `repo_type` is not set correctly, or because the repo
- > is `private` and you do not have access.
- > - [`~utils.GatedRepoError`]
- > If the repository exists but is gated and the user is not on the authorized
- > list.
- > - [`~utils.RevisionNotFoundError`]
- > If the repository exists but the revision couldn't be find.
- > - [`~utils.EntryNotFoundError`]
- > If the repository exists but the entry (e.g. the requested file) couldn't be
- > find.
- > - [`~utils.BadRequestError`]
- > If request failed with a HTTP 400 BadRequest error.
- > - [`~utils.HfHubHTTPError`]
- > If request failed for a reason not listed above.
- """
- try:
- response.raise_for_status()
- except HTTPError as e:
- error_code = response.headers.get("X-Error-Code")
- error_message = response.headers.get("X-Error-Message")
- if error_code == "RevisionNotFound":
- message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}."
- raise _format(RevisionNotFoundError, message, response) from e
- elif error_code == "EntryNotFound":
- message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}."
- raise _format(EntryNotFoundError, message, response) from e
- elif error_code == "GatedRepo":
- message = (
- f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}."
- )
- raise _format(GatedRepoError, message, response) from e
- elif error_message == "Access to this resource is disabled.":
- message = (
- f"{response.status_code} Client Error."
- + "\n\n"
- + f"Cannot access repository for url {response.url}."
- + "\n"
- + "Access to this resource is disabled."
- )
- raise _format(DisabledRepoError, message, response) from e
- elif error_code == "RepoNotFound" or (
- response.status_code == 401
- and error_message != "Invalid credentials in Authorization header"
- and response.request is not None
- and response.request.url is not None
- and REPO_API_REGEX.search(response.request.url) is not None
- ):
- # 401 is misleading as it is returned for:
- # - private and gated repos if user is not authenticated
- # - missing repos
- # => for now, we process them as `RepoNotFound` anyway.
- # See https://gist.github.com/Wauplin/46c27ad266b15998ce56a6603796f0b9
- message = (
- f"{response.status_code} Client Error."
- + "\n\n"
- + f"Repository Not Found for url: {response.url}."
- + "\nPlease make sure you specified the correct `repo_id` and"
- " `repo_type`.\nIf you are trying to access a private or gated repo,"
- " make sure you are authenticated. For more details, see"
- " https://huggingface.co/docs/huggingface_hub/authentication"
- )
- raise _format(RepositoryNotFoundError, message, response) from e
- elif response.status_code == 400:
- message = (
- f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:"
- )
- raise _format(BadRequestError, message, response) from e
- elif response.status_code == 403:
- message = (
- f"\n\n{response.status_code} Forbidden: {error_message}."
- + f"\nCannot access content at: {response.url}."
- + "\nMake sure your token has the correct permissions."
- )
- raise _format(HfHubHTTPError, message, response) from e
- elif response.status_code == 416:
- range_header = response.request.headers.get("Range")
- message = f"{e}. Requested range: {range_header}. Content-Range: {response.headers.get('Content-Range')}."
- raise _format(HfHubHTTPError, message, response) from e
- # Convert `HTTPError` into a `HfHubHTTPError` to display request information
- # as well (request id and/or server error message)
- raise _format(HfHubHTTPError, str(e), response) from e
- def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError:
- server_errors = []
- # Retrieve server error from header
- from_headers = response.headers.get("X-Error-Message")
- if from_headers is not None:
- server_errors.append(from_headers)
- # Retrieve server error from body
- try:
- # Case errors are returned in a JSON format
- data = response.json()
- error = data.get("error")
- if error is not None:
- if isinstance(error, list):
- # Case {'error': ['my error 1', 'my error 2']}
- server_errors.extend(error)
- else:
- # Case {'error': 'my error'}
- server_errors.append(error)
- errors = data.get("errors")
- if errors is not None:
- # Case {'errors': [{'message': 'my error 1'}, {'message': 'my error 2'}]}
- for error in errors:
- if "message" in error:
- server_errors.append(error["message"])
- except JSONDecodeError:
- # If content is not JSON and not HTML, append the text
- content_type = response.headers.get("Content-Type", "")
- if response.text and "html" not in content_type.lower():
- server_errors.append(response.text)
- # Strip all server messages
- server_errors = [str(line).strip() for line in server_errors if str(line).strip()]
- # Deduplicate server messages (keep order)
- # taken from https://stackoverflow.com/a/17016257
- server_errors = list(dict.fromkeys(server_errors))
- # Format server error
- server_message = "\n".join(server_errors)
- # Add server error to custom message
- final_error_message = custom_message
- if server_message and server_message.lower() not in custom_message.lower():
- if "\n\n" in custom_message:
- final_error_message += "\n" + server_message
- else:
- final_error_message += "\n\n" + server_message
- # Add Request ID
- request_id = str(response.headers.get(X_REQUEST_ID, ""))
- if request_id:
- request_id_message = f" (Request ID: {request_id})"
- else:
- # Fallback to X-Amzn-Trace-Id
- request_id = str(response.headers.get(X_AMZN_TRACE_ID, ""))
- if request_id:
- request_id_message = f" (Amzn Trace ID: {request_id})"
- if request_id and request_id.lower() not in final_error_message.lower():
- if "\n" in final_error_message:
- newline_index = final_error_message.index("\n")
- final_error_message = (
- final_error_message[:newline_index] + request_id_message + final_error_message[newline_index:]
- )
- else:
- final_error_message += request_id_message
- # Return
- return error_type(final_error_message.strip(), response=response, server_message=server_message or None)
- def _curlify(request: requests.PreparedRequest) -> str:
- """Convert a `requests.PreparedRequest` into a curl command (str).
- Used for debug purposes only.
- Implementation vendored from https://github.com/ofw/curlify/blob/master/curlify.py.
- MIT License Copyright (c) 2016 Egor.
- """
- parts: List[Tuple[Any, Any]] = [
- ("curl", None),
- ("-X", request.method),
- ]
- for k, v in sorted(request.headers.items()):
- if k.lower() == "authorization":
- v = "<TOKEN>" # Hide authorization header, no matter its value (can be Bearer, Key, etc.)
- parts += [("-H", "{0}: {1}".format(k, v))]
- if request.body:
- body = request.body
- if isinstance(body, bytes):
- body = body.decode("utf-8", errors="ignore")
- elif hasattr(body, "read"):
- body = "<file-like object>" # Don't try to read it to avoid consuming the stream
- if len(body) > 1000:
- body = body[:1000] + " ... [truncated]"
- parts += [("-d", body.replace("\n", ""))]
- parts += [(None, request.url)]
- flat_parts = []
- for k, v in parts:
- if k:
- flat_parts.append(quote(k))
- if v:
- flat_parts.append(quote(v))
- return " ".join(flat_parts)
- # Regex to parse HTTP Range header
- RANGE_REGEX = re.compile(r"^\s*bytes\s*=\s*(\d*)\s*-\s*(\d*)\s*$", re.IGNORECASE)
- def _adjust_range_header(original_range: Optional[str], resume_size: int) -> Optional[str]:
- """
- Adjust HTTP Range header to account for resume position.
- """
- if not original_range:
- return f"bytes={resume_size}-"
- if "," in original_range:
- raise ValueError(f"Multiple ranges detected - {original_range!r}, not supported yet.")
- match = RANGE_REGEX.match(original_range)
- if not match:
- raise RuntimeError(f"Invalid range format - {original_range!r}.")
- start, end = match.groups()
- if not start:
- if not end:
- raise RuntimeError(f"Invalid range format - {original_range!r}.")
- new_suffix = int(end) - resume_size
- new_range = f"bytes=-{new_suffix}"
- if new_suffix <= 0:
- raise RuntimeError(f"Empty new range - {new_range!r}.")
- return new_range
- start = int(start)
- new_start = start + resume_size
- if end:
- end = int(end)
- new_range = f"bytes={new_start}-{end}"
- if new_start > end:
- raise RuntimeError(f"Empty new range - {new_range!r}.")
- return new_range
- return f"bytes={new_start}-"
|