azure.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647
  1. from __future__ import annotations
  2. import os
  3. import inspect
  4. from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
  5. from typing_extensions import Self, override
  6. import httpx
  7. from .._types import NOT_GIVEN, Omit, Query, Timeout, NotGiven
  8. from .._utils import is_given, is_mapping
  9. from .._client import OpenAI, AsyncOpenAI
  10. from .._compat import model_copy
  11. from .._models import FinalRequestOptions
  12. from .._streaming import Stream, AsyncStream
  13. from .._exceptions import OpenAIError
  14. from .._base_client import DEFAULT_MAX_RETRIES, BaseClient
  15. _deployments_endpoints = set(
  16. [
  17. "/completions",
  18. "/chat/completions",
  19. "/embeddings",
  20. "/audio/transcriptions",
  21. "/audio/translations",
  22. "/audio/speech",
  23. "/images/generations",
  24. "/images/edits",
  25. ]
  26. )
# Zero-argument callable that returns an Azure AD token (used by the sync client).
AzureADTokenProvider = Callable[[], str]
# Async-client variant: the callable may return the token itself or an awaitable of it.
AsyncAzureADTokenProvider = Callable[[], "str | Awaitable[str]"]

# Type parameters that `BaseAzureClient` forwards to `BaseClient`.
_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])

# we need to use a sentinel API key value for Azure AD
# as we don't want to make the `api_key` in the main client Optional
# and Azure AD tokens may be retrieved on a per-request basis
# NOTE(review): `_prepare_options` tests `self.api_key is not API_KEY_SENTINEL`
# (identity, not equality); the value is presumably assembled with `join` so it
# is a distinct object rather than a shared literal -- confirm before simplifying.
API_KEY_SENTINEL = "".join(["<", "missing API key", ">"])
  35. class MutuallyExclusiveAuthError(OpenAIError):
  36. def __init__(self) -> None:
  37. super().__init__(
  38. "The `api_key`, `azure_ad_token` and `azure_ad_token_provider` arguments are mutually exclusive; Only one can be passed at a time"
  39. )
  40. class BaseAzureClient(BaseClient[_HttpxClientT, _DefaultStreamT]):
  41. _azure_endpoint: httpx.URL | None
  42. _azure_deployment: str | None
  43. @override
  44. def _build_request(
  45. self,
  46. options: FinalRequestOptions,
  47. *,
  48. retries_taken: int = 0,
  49. ) -> httpx.Request:
  50. if options.url in _deployments_endpoints and is_mapping(options.json_data):
  51. model = options.json_data.get("model")
  52. if model is not None and "/deployments" not in str(self.base_url.path):
  53. options.url = f"/deployments/{model}{options.url}"
  54. return super()._build_request(options, retries_taken=retries_taken)
  55. @override
  56. def _prepare_url(self, url: str) -> httpx.URL:
  57. """Adjust the URL if the client was configured with an Azure endpoint + deployment
  58. and the API feature being called is **not** a deployments-based endpoint
  59. (i.e. requires /deployments/deployment-name in the URL path).
  60. """
  61. if self._azure_deployment and self._azure_endpoint and url not in _deployments_endpoints:
  62. merge_url = httpx.URL(url)
  63. if merge_url.is_relative_url:
  64. merge_raw_path = (
  65. self._azure_endpoint.raw_path.rstrip(b"/") + b"/openai/" + merge_url.raw_path.lstrip(b"/")
  66. )
  67. return self._azure_endpoint.copy_with(raw_path=merge_raw_path)
  68. return merge_url
  69. return super()._prepare_url(url)
  70. class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
  71. @overload
  72. def __init__(
  73. self,
  74. *,
  75. azure_endpoint: str,
  76. azure_deployment: str | None = None,
  77. api_version: str | None = None,
  78. api_key: str | Callable[[], str] | None = None,
  79. azure_ad_token: str | None = None,
  80. azure_ad_token_provider: AzureADTokenProvider | None = None,
  81. organization: str | None = None,
  82. webhook_secret: str | None = None,
  83. websocket_base_url: str | httpx.URL | None = None,
  84. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  85. max_retries: int = DEFAULT_MAX_RETRIES,
  86. default_headers: Mapping[str, str] | None = None,
  87. default_query: Mapping[str, object] | None = None,
  88. http_client: httpx.Client | None = None,
  89. _strict_response_validation: bool = False,
  90. ) -> None: ...
  91. @overload
  92. def __init__(
  93. self,
  94. *,
  95. azure_deployment: str | None = None,
  96. api_version: str | None = None,
  97. api_key: str | Callable[[], str] | None = None,
  98. azure_ad_token: str | None = None,
  99. azure_ad_token_provider: AzureADTokenProvider | None = None,
  100. organization: str | None = None,
  101. webhook_secret: str | None = None,
  102. websocket_base_url: str | httpx.URL | None = None,
  103. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  104. max_retries: int = DEFAULT_MAX_RETRIES,
  105. default_headers: Mapping[str, str] | None = None,
  106. default_query: Mapping[str, object] | None = None,
  107. http_client: httpx.Client | None = None,
  108. _strict_response_validation: bool = False,
  109. ) -> None: ...
  110. @overload
  111. def __init__(
  112. self,
  113. *,
  114. base_url: str,
  115. api_version: str | None = None,
  116. api_key: str | Callable[[], str] | None = None,
  117. azure_ad_token: str | None = None,
  118. azure_ad_token_provider: AzureADTokenProvider | None = None,
  119. organization: str | None = None,
  120. webhook_secret: str | None = None,
  121. websocket_base_url: str | httpx.URL | None = None,
  122. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  123. max_retries: int = DEFAULT_MAX_RETRIES,
  124. default_headers: Mapping[str, str] | None = None,
  125. default_query: Mapping[str, object] | None = None,
  126. http_client: httpx.Client | None = None,
  127. _strict_response_validation: bool = False,
  128. ) -> None: ...
  129. def __init__(
  130. self,
  131. *,
  132. api_version: str | None = None,
  133. azure_endpoint: str | None = None,
  134. azure_deployment: str | None = None,
  135. api_key: str | Callable[[], str] | None = None,
  136. azure_ad_token: str | None = None,
  137. azure_ad_token_provider: AzureADTokenProvider | None = None,
  138. organization: str | None = None,
  139. project: str | None = None,
  140. webhook_secret: str | None = None,
  141. websocket_base_url: str | httpx.URL | None = None,
  142. base_url: str | None = None,
  143. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  144. max_retries: int = DEFAULT_MAX_RETRIES,
  145. default_headers: Mapping[str, str] | None = None,
  146. default_query: Mapping[str, object] | None = None,
  147. http_client: httpx.Client | None = None,
  148. _strict_response_validation: bool = False,
  149. ) -> None:
  150. """Construct a new synchronous azure openai client instance.
  151. This automatically infers the following arguments from their corresponding environment variables if they are not provided:
  152. - `api_key` from `AZURE_OPENAI_API_KEY`
  153. - `organization` from `OPENAI_ORG_ID`
  154. - `project` from `OPENAI_PROJECT_ID`
  155. - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
  156. - `api_version` from `OPENAI_API_VERSION`
  157. - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
  158. Args:
  159. azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
  160. azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
  161. azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
  162. azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
  163. Not supported with Assistants APIs.
  164. """
  165. if api_key is None:
  166. api_key = os.environ.get("AZURE_OPENAI_API_KEY")
  167. if azure_ad_token is None:
  168. azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
  169. if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
  170. raise OpenAIError(
  171. "Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
  172. )
  173. if api_version is None:
  174. api_version = os.environ.get("OPENAI_API_VERSION")
  175. if api_version is None:
  176. raise ValueError(
  177. "Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
  178. )
  179. if default_query is None:
  180. default_query = {"api-version": api_version}
  181. else:
  182. default_query = {**default_query, "api-version": api_version}
  183. if base_url is None:
  184. if azure_endpoint is None:
  185. azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
  186. if azure_endpoint is None:
  187. raise ValueError(
  188. "Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
  189. )
  190. if azure_deployment is not None:
  191. base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
  192. else:
  193. base_url = f"{azure_endpoint.rstrip('/')}/openai"
  194. else:
  195. if azure_endpoint is not None:
  196. raise ValueError("base_url and azure_endpoint are mutually exclusive")
  197. if api_key is None:
  198. # define a sentinel value to avoid any typing issues
  199. api_key = API_KEY_SENTINEL
  200. super().__init__(
  201. api_key=api_key,
  202. organization=organization,
  203. project=project,
  204. webhook_secret=webhook_secret,
  205. base_url=base_url,
  206. timeout=timeout,
  207. max_retries=max_retries,
  208. default_headers=default_headers,
  209. default_query=default_query,
  210. http_client=http_client,
  211. websocket_base_url=websocket_base_url,
  212. _strict_response_validation=_strict_response_validation,
  213. )
  214. self._api_version = api_version
  215. self._azure_ad_token = azure_ad_token
  216. self._azure_ad_token_provider = azure_ad_token_provider
  217. self._azure_deployment = azure_deployment if azure_endpoint else None
  218. self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
  219. @override
  220. def copy(
  221. self,
  222. *,
  223. api_key: str | Callable[[], str] | None = None,
  224. organization: str | None = None,
  225. project: str | None = None,
  226. webhook_secret: str | None = None,
  227. websocket_base_url: str | httpx.URL | None = None,
  228. api_version: str | None = None,
  229. azure_ad_token: str | None = None,
  230. azure_ad_token_provider: AzureADTokenProvider | None = None,
  231. base_url: str | httpx.URL | None = None,
  232. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  233. http_client: httpx.Client | None = None,
  234. max_retries: int | NotGiven = NOT_GIVEN,
  235. default_headers: Mapping[str, str] | None = None,
  236. set_default_headers: Mapping[str, str] | None = None,
  237. default_query: Mapping[str, object] | None = None,
  238. set_default_query: Mapping[str, object] | None = None,
  239. _extra_kwargs: Mapping[str, Any] = {},
  240. ) -> Self:
  241. """
  242. Create a new client instance re-using the same options given to the current client with optional overriding.
  243. """
  244. return super().copy(
  245. api_key=api_key,
  246. organization=organization,
  247. project=project,
  248. webhook_secret=webhook_secret,
  249. websocket_base_url=websocket_base_url,
  250. base_url=base_url,
  251. timeout=timeout,
  252. http_client=http_client,
  253. max_retries=max_retries,
  254. default_headers=default_headers,
  255. set_default_headers=set_default_headers,
  256. default_query=default_query,
  257. set_default_query=set_default_query,
  258. _extra_kwargs={
  259. "api_version": api_version or self._api_version,
  260. "azure_ad_token": azure_ad_token or self._azure_ad_token,
  261. "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
  262. **_extra_kwargs,
  263. },
  264. )
  265. with_options = copy
  266. def _get_azure_ad_token(self) -> str | None:
  267. if self._azure_ad_token is not None:
  268. return self._azure_ad_token
  269. provider = self._azure_ad_token_provider
  270. if provider is not None:
  271. token = provider()
  272. if not token or not isinstance(token, str): # pyright: ignore[reportUnnecessaryIsInstance]
  273. raise ValueError(
  274. f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
  275. )
  276. return token
  277. return None
  278. @override
  279. def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
  280. headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
  281. options = model_copy(options)
  282. options.headers = headers
  283. azure_ad_token = self._get_azure_ad_token()
  284. if azure_ad_token is not None:
  285. if headers.get("Authorization") is None:
  286. headers["Authorization"] = f"Bearer {azure_ad_token}"
  287. elif self.api_key is not API_KEY_SENTINEL:
  288. if headers.get("api-key") is None:
  289. headers["api-key"] = self.api_key
  290. else:
  291. # should never be hit
  292. raise ValueError("Unable to handle auth")
  293. return options
  294. def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
  295. auth_headers = {}
  296. query = {
  297. **extra_query,
  298. "api-version": self._api_version,
  299. "deployment": self._azure_deployment or model,
  300. }
  301. if self.api_key and self.api_key != "<missing API key>":
  302. auth_headers = {"api-key": self.api_key}
  303. else:
  304. token = self._get_azure_ad_token()
  305. if token:
  306. auth_headers = {"Authorization": f"Bearer {token}"}
  307. if self.websocket_base_url is not None:
  308. base_url = httpx.URL(self.websocket_base_url)
  309. merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
  310. realtime_url = base_url.copy_with(raw_path=merge_raw_path)
  311. else:
  312. base_url = self._prepare_url("/realtime")
  313. realtime_url = base_url.copy_with(scheme="wss")
  314. url = realtime_url.copy_with(params={**query})
  315. return url, auth_headers
  316. class AsyncAzureOpenAI(BaseAzureClient[httpx.AsyncClient, AsyncStream[Any]], AsyncOpenAI):
  317. @overload
  318. def __init__(
  319. self,
  320. *,
  321. azure_endpoint: str,
  322. azure_deployment: str | None = None,
  323. api_version: str | None = None,
  324. api_key: str | Callable[[], Awaitable[str]] | None = None,
  325. azure_ad_token: str | None = None,
  326. azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
  327. organization: str | None = None,
  328. project: str | None = None,
  329. webhook_secret: str | None = None,
  330. websocket_base_url: str | httpx.URL | None = None,
  331. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  332. max_retries: int = DEFAULT_MAX_RETRIES,
  333. default_headers: Mapping[str, str] | None = None,
  334. default_query: Mapping[str, object] | None = None,
  335. http_client: httpx.AsyncClient | None = None,
  336. _strict_response_validation: bool = False,
  337. ) -> None: ...
  338. @overload
  339. def __init__(
  340. self,
  341. *,
  342. azure_deployment: str | None = None,
  343. api_version: str | None = None,
  344. api_key: str | Callable[[], Awaitable[str]] | None = None,
  345. azure_ad_token: str | None = None,
  346. azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
  347. organization: str | None = None,
  348. project: str | None = None,
  349. webhook_secret: str | None = None,
  350. websocket_base_url: str | httpx.URL | None = None,
  351. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  352. max_retries: int = DEFAULT_MAX_RETRIES,
  353. default_headers: Mapping[str, str] | None = None,
  354. default_query: Mapping[str, object] | None = None,
  355. http_client: httpx.AsyncClient | None = None,
  356. _strict_response_validation: bool = False,
  357. ) -> None: ...
  358. @overload
  359. def __init__(
  360. self,
  361. *,
  362. base_url: str,
  363. api_version: str | None = None,
  364. api_key: str | Callable[[], Awaitable[str]] | None = None,
  365. azure_ad_token: str | None = None,
  366. azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
  367. organization: str | None = None,
  368. project: str | None = None,
  369. webhook_secret: str | None = None,
  370. websocket_base_url: str | httpx.URL | None = None,
  371. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  372. max_retries: int = DEFAULT_MAX_RETRIES,
  373. default_headers: Mapping[str, str] | None = None,
  374. default_query: Mapping[str, object] | None = None,
  375. http_client: httpx.AsyncClient | None = None,
  376. _strict_response_validation: bool = False,
  377. ) -> None: ...
  378. def __init__(
  379. self,
  380. *,
  381. azure_endpoint: str | None = None,
  382. azure_deployment: str | None = None,
  383. api_version: str | None = None,
  384. api_key: str | Callable[[], Awaitable[str]] | None = None,
  385. azure_ad_token: str | None = None,
  386. azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
  387. organization: str | None = None,
  388. project: str | None = None,
  389. webhook_secret: str | None = None,
  390. base_url: str | None = None,
  391. websocket_base_url: str | httpx.URL | None = None,
  392. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  393. max_retries: int = DEFAULT_MAX_RETRIES,
  394. default_headers: Mapping[str, str] | None = None,
  395. default_query: Mapping[str, object] | None = None,
  396. http_client: httpx.AsyncClient | None = None,
  397. _strict_response_validation: bool = False,
  398. ) -> None:
  399. """Construct a new asynchronous azure openai client instance.
  400. This automatically infers the following arguments from their corresponding environment variables if they are not provided:
  401. - `api_key` from `AZURE_OPENAI_API_KEY`
  402. - `organization` from `OPENAI_ORG_ID`
  403. - `project` from `OPENAI_PROJECT_ID`
  404. - `azure_ad_token` from `AZURE_OPENAI_AD_TOKEN`
  405. - `api_version` from `OPENAI_API_VERSION`
  406. - `azure_endpoint` from `AZURE_OPENAI_ENDPOINT`
  407. Args:
  408. azure_endpoint: Your Azure endpoint, including the resource, e.g. `https://example-resource.azure.openai.com/`
  409. azure_ad_token: Your Azure Active Directory token, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id
  410. azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on every request.
  411. azure_deployment: A model deployment, if given with `azure_endpoint`, sets the base client URL to include `/deployments/{azure_deployment}`.
  412. Not supported with Assistants APIs.
  413. """
  414. if api_key is None:
  415. api_key = os.environ.get("AZURE_OPENAI_API_KEY")
  416. if azure_ad_token is None:
  417. azure_ad_token = os.environ.get("AZURE_OPENAI_AD_TOKEN")
  418. if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
  419. raise OpenAIError(
  420. "Missing credentials. Please pass one of `api_key`, `azure_ad_token`, `azure_ad_token_provider`, or the `AZURE_OPENAI_API_KEY` or `AZURE_OPENAI_AD_TOKEN` environment variables."
  421. )
  422. if api_version is None:
  423. api_version = os.environ.get("OPENAI_API_VERSION")
  424. if api_version is None:
  425. raise ValueError(
  426. "Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"
  427. )
  428. if default_query is None:
  429. default_query = {"api-version": api_version}
  430. else:
  431. default_query = {**default_query, "api-version": api_version}
  432. if base_url is None:
  433. if azure_endpoint is None:
  434. azure_endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT")
  435. if azure_endpoint is None:
  436. raise ValueError(
  437. "Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"
  438. )
  439. if azure_deployment is not None:
  440. base_url = f"{azure_endpoint.rstrip('/')}/openai/deployments/{azure_deployment}"
  441. else:
  442. base_url = f"{azure_endpoint.rstrip('/')}/openai"
  443. else:
  444. if azure_endpoint is not None:
  445. raise ValueError("base_url and azure_endpoint are mutually exclusive")
  446. if api_key is None:
  447. # define a sentinel value to avoid any typing issues
  448. api_key = API_KEY_SENTINEL
  449. super().__init__(
  450. api_key=api_key,
  451. organization=organization,
  452. project=project,
  453. webhook_secret=webhook_secret,
  454. base_url=base_url,
  455. timeout=timeout,
  456. max_retries=max_retries,
  457. default_headers=default_headers,
  458. default_query=default_query,
  459. http_client=http_client,
  460. websocket_base_url=websocket_base_url,
  461. _strict_response_validation=_strict_response_validation,
  462. )
  463. self._api_version = api_version
  464. self._azure_ad_token = azure_ad_token
  465. self._azure_ad_token_provider = azure_ad_token_provider
  466. self._azure_deployment = azure_deployment if azure_endpoint else None
  467. self._azure_endpoint = httpx.URL(azure_endpoint) if azure_endpoint else None
  468. @override
  469. def copy(
  470. self,
  471. *,
  472. api_key: str | Callable[[], Awaitable[str]] | None = None,
  473. organization: str | None = None,
  474. project: str | None = None,
  475. webhook_secret: str | None = None,
  476. websocket_base_url: str | httpx.URL | None = None,
  477. api_version: str | None = None,
  478. azure_ad_token: str | None = None,
  479. azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
  480. base_url: str | httpx.URL | None = None,
  481. timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
  482. http_client: httpx.AsyncClient | None = None,
  483. max_retries: int | NotGiven = NOT_GIVEN,
  484. default_headers: Mapping[str, str] | None = None,
  485. set_default_headers: Mapping[str, str] | None = None,
  486. default_query: Mapping[str, object] | None = None,
  487. set_default_query: Mapping[str, object] | None = None,
  488. _extra_kwargs: Mapping[str, Any] = {},
  489. ) -> Self:
  490. """
  491. Create a new client instance re-using the same options given to the current client with optional overriding.
  492. """
  493. return super().copy(
  494. api_key=api_key,
  495. organization=organization,
  496. project=project,
  497. webhook_secret=webhook_secret,
  498. websocket_base_url=websocket_base_url,
  499. base_url=base_url,
  500. timeout=timeout,
  501. http_client=http_client,
  502. max_retries=max_retries,
  503. default_headers=default_headers,
  504. set_default_headers=set_default_headers,
  505. default_query=default_query,
  506. set_default_query=set_default_query,
  507. _extra_kwargs={
  508. "api_version": api_version or self._api_version,
  509. "azure_ad_token": azure_ad_token or self._azure_ad_token,
  510. "azure_ad_token_provider": azure_ad_token_provider or self._azure_ad_token_provider,
  511. **_extra_kwargs,
  512. },
  513. )
  514. with_options = copy
  515. async def _get_azure_ad_token(self) -> str | None:
  516. if self._azure_ad_token is not None:
  517. return self._azure_ad_token
  518. provider = self._azure_ad_token_provider
  519. if provider is not None:
  520. token = provider()
  521. if inspect.isawaitable(token):
  522. token = await token
  523. if not token or not isinstance(cast(Any, token), str):
  524. raise ValueError(
  525. f"Expected `azure_ad_token_provider` argument to return a string but it returned {token}",
  526. )
  527. return str(token)
  528. return None
  529. @override
  530. async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
  531. headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {}
  532. options = model_copy(options)
  533. options.headers = headers
  534. azure_ad_token = await self._get_azure_ad_token()
  535. if azure_ad_token is not None:
  536. if headers.get("Authorization") is None:
  537. headers["Authorization"] = f"Bearer {azure_ad_token}"
  538. elif self.api_key is not API_KEY_SENTINEL:
  539. if headers.get("api-key") is None:
  540. headers["api-key"] = self.api_key
  541. else:
  542. # should never be hit
  543. raise ValueError("Unable to handle auth")
  544. return options
  545. async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL, dict[str, str]]:
  546. auth_headers = {}
  547. query = {
  548. **extra_query,
  549. "api-version": self._api_version,
  550. "deployment": self._azure_deployment or model,
  551. }
  552. if self.api_key and self.api_key != "<missing API key>":
  553. auth_headers = {"api-key": self.api_key}
  554. else:
  555. token = await self._get_azure_ad_token()
  556. if token:
  557. auth_headers = {"Authorization": f"Bearer {token}"}
  558. if self.websocket_base_url is not None:
  559. base_url = httpx.URL(self.websocket_base_url)
  560. merge_raw_path = base_url.raw_path.rstrip(b"/") + b"/realtime"
  561. realtime_url = base_url.copy_with(raw_path=merge_raw_path)
  562. else:
  563. base_url = self._prepare_url("/realtime")
  564. realtime_url = base_url.copy_with(scheme="wss")
  565. url = realtime_url.copy_with(params={**query})
  566. return url, auth_headers