inference_api.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import io
  2. from typing import Any, Dict, List, Optional, Union
  3. from . import constants
  4. from .hf_api import HfApi
  5. from .utils import build_hf_headers, get_session, is_pillow_available, logging, validate_hf_hub_args
  6. from .utils._deprecation import _deprecate_method
  7. logger = logging.get_logger(__name__)
  8. ALL_TASKS = [
  9. # NLP
  10. "text-classification",
  11. "token-classification",
  12. "table-question-answering",
  13. "question-answering",
  14. "zero-shot-classification",
  15. "translation",
  16. "summarization",
  17. "conversational",
  18. "feature-extraction",
  19. "text-generation",
  20. "text2text-generation",
  21. "fill-mask",
  22. "sentence-similarity",
  23. # Audio
  24. "text-to-speech",
  25. "automatic-speech-recognition",
  26. "audio-to-audio",
  27. "audio-classification",
  28. "voice-activity-detection",
  29. # Computer vision
  30. "image-classification",
  31. "object-detection",
  32. "image-segmentation",
  33. "text-to-image",
  34. "image-to-image",
  35. # Others
  36. "tabular-classification",
  37. "tabular-regression",
  38. ]
  39. class InferenceApi:
  40. """Client to configure requests and make calls to the HuggingFace Inference API.
  41. Example:
  42. ```python
  43. >>> from huggingface_hub.inference_api import InferenceApi
  44. >>> # Mask-fill example
  45. >>> inference = InferenceApi("bert-base-uncased")
  46. >>> inference(inputs="The goal of life is [MASK].")
  47. [{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]
  48. >>> # Question Answering example
  49. >>> inference = InferenceApi("deepset/roberta-base-squad2")
  50. >>> inputs = {
  51. ... "question": "What's my name?",
  52. ... "context": "My name is Clara and I live in Berkeley.",
  53. ... }
  54. >>> inference(inputs)
  55. {'score': 0.9326569437980652, 'start': 11, 'end': 16, 'answer': 'Clara'}
  56. >>> # Zero-shot example
  57. >>> inference = InferenceApi("typeform/distilbert-base-uncased-mnli")
  58. >>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
  59. >>> params = {"candidate_labels": ["refund", "legal", "faq"]}
  60. >>> inference(inputs, params)
  61. {'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}
  62. >>> # Overriding configured task
  63. >>> inference = InferenceApi("bert-base-uncased", task="feature-extraction")
  64. >>> # Text-to-image
  65. >>> inference = InferenceApi("stabilityai/stable-diffusion-2-1")
  66. >>> inference("cat")
  67. <PIL.PngImagePlugin.PngImageFile image (...)>
  68. >>> # Return as raw response to parse the output yourself
  69. >>> inference = InferenceApi("mio/amadeus")
  70. >>> response = inference("hello world", raw_response=True)
  71. >>> response.headers
  72. {"Content-Type": "audio/flac", ...}
  73. >>> response.content # raw bytes from server
  74. b'(...)'
  75. ```
  76. """
  77. @validate_hf_hub_args
  78. @_deprecate_method(
  79. version="1.0",
  80. message=(
  81. "`InferenceApi` client is deprecated in favor of the more feature-complete `InferenceClient`. Check out"
  82. " this guide to learn how to convert your script to use it:"
  83. " https://huggingface.co/docs/huggingface_hub/guides/inference#legacy-inferenceapi-client."
  84. ),
  85. )
  86. def __init__(
  87. self,
  88. repo_id: str,
  89. task: Optional[str] = None,
  90. token: Optional[str] = None,
  91. gpu: bool = False,
  92. ):
  93. """Inits headers and API call information.
  94. Args:
  95. repo_id (``str``):
  96. Id of repository (e.g. `user/bert-base-uncased`).
  97. task (``str``, `optional`, defaults ``None``):
  98. Whether to force a task instead of using task specified in the
  99. repository.
  100. token (`str`, `optional`):
  101. The API token to use as HTTP bearer authorization. This is not
  102. the authentication token. You can find the token in
  103. https://huggingface.co/settings/token. Alternatively, you can
  104. find both your organizations and personal API tokens using
  105. `HfApi().whoami(token)`.
  106. gpu (`bool`, `optional`, defaults `False`):
  107. Whether to use GPU instead of CPU for inference(requires Startup
  108. plan at least).
  109. """
  110. self.options = {"wait_for_model": True, "use_gpu": gpu}
  111. self.headers = build_hf_headers(token=token)
  112. # Configure task
  113. model_info = HfApi(token=token).model_info(repo_id=repo_id)
  114. if not model_info.pipeline_tag and not task:
  115. raise ValueError(
  116. "Task not specified in the repository. Please add it to the model card"
  117. " using pipeline_tag"
  118. " (https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined)"
  119. )
  120. if task and task != model_info.pipeline_tag:
  121. if task not in ALL_TASKS:
  122. raise ValueError(f"Invalid task {task}. Make sure it's valid.")
  123. logger.warning(
  124. "You're using a different task than the one specified in the"
  125. " repository. Be sure to know what you're doing :)"
  126. )
  127. self.task = task
  128. else:
  129. assert model_info.pipeline_tag is not None, "Pipeline tag cannot be None"
  130. self.task = model_info.pipeline_tag
  131. self.api_url = f"{constants.INFERENCE_ENDPOINT}/pipeline/{self.task}/{repo_id}"
  132. def __repr__(self):
  133. # Do not add headers to repr to avoid leaking token.
  134. return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})"
  135. def __call__(
  136. self,
  137. inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
  138. params: Optional[Dict] = None,
  139. data: Optional[bytes] = None,
  140. raw_response: bool = False,
  141. ) -> Any:
  142. """Make a call to the Inference API.
  143. Args:
  144. inputs (`str` or `Dict` or `List[str]` or `List[List[str]]`, *optional*):
  145. Inputs for the prediction.
  146. params (`Dict`, *optional*):
  147. Additional parameters for the models. Will be sent as `parameters` in the
  148. payload.
  149. data (`bytes`, *optional*):
  150. Bytes content of the request. In this case, leave `inputs` and `params` empty.
  151. raw_response (`bool`, defaults to `False`):
  152. If `True`, the raw `Response` object is returned. You can parse its content
  153. as preferred. By default, the content is parsed into a more practical format
  154. (json dictionary or PIL Image for example).
  155. """
  156. # Build payload
  157. payload: Dict[str, Any] = {
  158. "options": self.options,
  159. }
  160. if inputs:
  161. payload["inputs"] = inputs
  162. if params:
  163. payload["parameters"] = params
  164. # Make API call
  165. response = get_session().post(self.api_url, headers=self.headers, json=payload, data=data)
  166. # Let the user handle the response
  167. if raw_response:
  168. return response
  169. # By default, parse the response for the user.
  170. content_type = response.headers.get("Content-Type") or ""
  171. if content_type.startswith("image"):
  172. if not is_pillow_available():
  173. raise ImportError(
  174. f"Task '{self.task}' returned as image but Pillow is not installed."
  175. " Please install it (`pip install Pillow`) or pass"
  176. " `raw_response=True` to get the raw `Response` object and parse"
  177. " the image by yourself."
  178. )
  179. from PIL import Image
  180. return Image.open(io.BytesIO(response.content))
  181. elif content_type == "application/json":
  182. return response.json()
  183. else:
  184. raise NotImplementedError(
  185. f"{content_type} output type is not implemented yet. You can pass"
  186. " `raw_response=True` to get the raw `Response` object and parse the"
  187. " output by yourself."
  188. )