_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. from __future__ import annotations
  2. import os
  3. import re
  4. import inspect
  5. import functools
  6. from typing import (
  7. TYPE_CHECKING,
  8. Any,
  9. Tuple,
  10. Mapping,
  11. TypeVar,
  12. Callable,
  13. Iterable,
  14. Sequence,
  15. cast,
  16. overload,
  17. )
  18. from pathlib import Path
  19. from datetime import date, datetime
  20. from typing_extensions import TypeGuard
  21. import sniffio
  22. from .._types import Omit, NotGiven, FileTypes, HeadersLike
# Generic type variables shared by the helpers in this module.
_T = TypeVar("_T")
# Bounded variants used by the `is_*_t` type guards.
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
# Any callable; used by decorators that return the decorated function's own type.
CallableT = TypeVar("CallableT", bound=Callable[..., Any])

if TYPE_CHECKING:
    # Imported for annotations only; the runtime checks import these lazily inside the functions.
    from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI
  30. def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
  31. return [item for sublist in t for item in sublist]
  32. def extract_files(
  33. # TODO: this needs to take Dict but variance issues.....
  34. # create protocol type ?
  35. query: Mapping[str, object],
  36. *,
  37. paths: Sequence[Sequence[str]],
  38. ) -> list[tuple[str, FileTypes]]:
  39. """Recursively extract files from the given dictionary based on specified paths.
  40. A path may look like this ['foo', 'files', '<array>', 'data'].
  41. Note: this mutates the given dictionary.
  42. """
  43. files: list[tuple[str, FileTypes]] = []
  44. for path in paths:
  45. files.extend(_extract_items(query, path, index=0, flattened_key=None))
  46. return files
def _extract_items(
    obj: object,
    path: Sequence[str],
    *,
    index: int,
    flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
    """Follow `path` through `obj`, starting at `index`, and collect the file entries it points at.

    `flattened_key` is the bracket-style key built so far: the first dict key is
    used as-is, each further dict key is appended as `[key]`, and each list level
    is appended as `[]`. It is `None` until the first segment has been consumed.

    Returns a list of `(flattened_key, file)` tuples. The list is empty when the
    path does not resolve (missing dict key, non-`<array>` segment applied to a
    list, or an object that is neither a dict nor a list) or when the resolved
    value is not given.

    When the final path segment names a dict key, that key is popped from its dict.
    """
    try:
        key = path[index]
    except IndexError:
        # `index` is past the end of `path`: `obj` is the value the path points at.
        if not is_given(obj):
            # no value was provided - we can safely ignore
            return []

        # cyclical import
        from .._files import assert_is_file_content

        # We have exhausted the path, return the entry we found.
        assert flattened_key is not None

        if is_list(obj):
            files: list[tuple[str, FileTypes]] = []
            for entry in obj:
                # Parsed as `(flattened_key + "[]") if flattened_key else ""`.
                assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
                files.append((flattened_key + "[]", cast(FileTypes, entry)))
            return files

        assert_is_file_content(obj, key=flattened_key)
        return [(flattened_key, cast(FileTypes, obj))]

    # Advance past the segment we just read; `index` now points at the next segment.
    index += 1
    if is_dict(obj):
        try:
            # We are at the last entry in the path so we must remove the field
            if (len(path)) == index:
                item = obj.pop(key)
            else:
                item = obj[key]
        except KeyError:
            # Key was not present in the dictionary, this is not indicative of an error
            # as the given path may not point to a required field. We also do not want
            # to enforce required fields as the API may differ from the spec in some cases.
            return []
        if flattened_key is None:
            flattened_key = key
        else:
            flattened_key += f"[{key}]"
        return _extract_items(
            item,
            path,
            index=index,
            flattened_key=flattened_key,
        )
    elif is_list(obj):
        # A list can only be traversed by the `<array>` placeholder segment.
        if key != "<array>":
            return []

        return flatten(
            [
                _extract_items(
                    item,
                    path,
                    index=index,
                    # Parsed as `(flattened_key + "[]") if flattened_key is not None else "[]"`.
                    flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
                )
                for item in obj
            ]
        )

    # Something unexpected was passed, just ignore it.
    return []
  111. def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
  112. return not isinstance(obj, NotGiven) and not isinstance(obj, Omit)
# Type safe methods for narrowing types with TypeVars.
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
# however this causes Pyright to rightfully report errors. As we know we don't
# care about the contained types we can safely use `object` in its place.
#
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
# `is_*` is for when you're dealing with an unknown input
# `is_*_t` is for when you're narrowing a known union type to a specific subset
  121. def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
  122. return isinstance(obj, tuple)
  123. def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
  124. return isinstance(obj, tuple)
  125. def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
  126. return isinstance(obj, Sequence)
  127. def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
  128. return isinstance(obj, Sequence)
  129. def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
  130. return isinstance(obj, Mapping)
  131. def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
  132. return isinstance(obj, Mapping)
  133. def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
  134. return isinstance(obj, dict)
  135. def is_list(obj: object) -> TypeGuard[list[object]]:
  136. return isinstance(obj, list)
  137. def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
  138. return isinstance(obj, Iterable)
  139. def deepcopy_minimal(item: _T) -> _T:
  140. """Minimal reimplementation of copy.deepcopy() that will only copy certain object types:
  141. - mappings, e.g. `dict`
  142. - list
  143. This is done for performance reasons.
  144. """
  145. if is_mapping(item):
  146. return cast(_T, {k: deepcopy_minimal(v) for k, v in item.items()})
  147. if is_list(item):
  148. return cast(_T, [deepcopy_minimal(entry) for entry in item])
  149. return item
  150. # copied from https://github.com/Rapptz/RoboDanny
  151. def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
  152. size = len(seq)
  153. if size == 0:
  154. return ""
  155. if size == 1:
  156. return seq[0]
  157. if size == 2:
  158. return f"{seq[0]} {final} {seq[1]}"
  159. return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
  160. def quote(string: str) -> str:
  161. """Add single quotation marks around the given string. Does *not* do any escaping."""
  162. return f"'{string}'"
  163. def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
  164. """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.
  165. Useful for enforcing runtime validation of overloaded functions.
  166. Example usage:
  167. ```py
  168. @overload
  169. def foo(*, a: str) -> str: ...
  170. @overload
  171. def foo(*, b: bool) -> str: ...
  172. # This enforces the same constraints that a static type checker would
  173. # i.e. that either a or b must be passed to the function
  174. @required_args(["a"], ["b"])
  175. def foo(*, a: str | None = None, b: bool | None = None) -> str: ...
  176. ```
  177. """
  178. def inner(func: CallableT) -> CallableT:
  179. params = inspect.signature(func).parameters
  180. positional = [
  181. name
  182. for name, param in params.items()
  183. if param.kind
  184. in {
  185. param.POSITIONAL_ONLY,
  186. param.POSITIONAL_OR_KEYWORD,
  187. }
  188. ]
  189. @functools.wraps(func)
  190. def wrapper(*args: object, **kwargs: object) -> object:
  191. given_params: set[str] = set()
  192. for i, _ in enumerate(args):
  193. try:
  194. given_params.add(positional[i])
  195. except IndexError:
  196. raise TypeError(
  197. f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
  198. ) from None
  199. for key in kwargs.keys():
  200. given_params.add(key)
  201. for variant in variants:
  202. matches = all((param in given_params for param in variant))
  203. if matches:
  204. break
  205. else: # no break
  206. if len(variants) > 1:
  207. variations = human_join(
  208. ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
  209. )
  210. msg = f"Missing required arguments; Expected either {variations} arguments to be given"
  211. else:
  212. assert len(variants) > 0
  213. # TODO: this error message is not deterministic
  214. missing = list(set(variants[0]) - given_params)
  215. if len(missing) > 1:
  216. msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
  217. else:
  218. msg = f"Missing required argument: {quote(missing[0])}"
  219. raise TypeError(msg)
  220. return func(*args, **kwargs)
  221. return wrapper # type: ignore
  222. return inner
  223. _K = TypeVar("_K")
  224. _V = TypeVar("_V")
  225. @overload
  226. def strip_not_given(obj: None) -> None: ...
  227. @overload
  228. def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...
  229. @overload
  230. def strip_not_given(obj: object) -> object: ...
  231. def strip_not_given(obj: object | None) -> object:
  232. """Remove all top-level keys where their values are instances of `NotGiven`"""
  233. if obj is None:
  234. return None
  235. if not is_mapping(obj):
  236. return obj
  237. return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
  238. def coerce_integer(val: str) -> int:
  239. return int(val, base=10)
  240. def coerce_float(val: str) -> float:
  241. return float(val)
  242. def coerce_boolean(val: str) -> bool:
  243. return val == "true" or val == "1" or val == "on"
  244. def maybe_coerce_integer(val: str | None) -> int | None:
  245. if val is None:
  246. return None
  247. return coerce_integer(val)
  248. def maybe_coerce_float(val: str | None) -> float | None:
  249. if val is None:
  250. return None
  251. return coerce_float(val)
  252. def maybe_coerce_boolean(val: str | None) -> bool | None:
  253. if val is None:
  254. return None
  255. return coerce_boolean(val)
  256. def removeprefix(string: str, prefix: str) -> str:
  257. """Remove a prefix from a string.
  258. Backport of `str.removeprefix` for Python < 3.9
  259. """
  260. if string.startswith(prefix):
  261. return string[len(prefix) :]
  262. return string
  263. def removesuffix(string: str, suffix: str) -> str:
  264. """Remove a suffix from a string.
  265. Backport of `str.removesuffix` for Python < 3.9
  266. """
  267. if string.endswith(suffix):
  268. return string[: -len(suffix)]
  269. return string
  270. def file_from_path(path: str) -> FileTypes:
  271. contents = Path(path).read_bytes()
  272. file_name = os.path.basename(path)
  273. return (file_name, contents)
  274. def get_required_header(headers: HeadersLike, header: str) -> str:
  275. lower_header = header.lower()
  276. if is_mapping_t(headers):
  277. # mypy doesn't understand the type narrowing here
  278. for k, v in headers.items(): # type: ignore
  279. if k.lower() == lower_header and isinstance(v, str):
  280. return v
  281. # to deal with the case where the header looks like Stainless-Event-Id
  282. intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())
  283. for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
  284. value = headers.get(normalized_header)
  285. if value:
  286. return value
  287. raise ValueError(f"Could not find {header} header")
  288. def get_async_library() -> str:
  289. try:
  290. return sniffio.current_async_library()
  291. except Exception:
  292. return "false"
  293. def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
  294. """A version of functools.lru_cache that retains the type signature
  295. for the wrapped function arguments.
  296. """
  297. wrapper = functools.lru_cache( # noqa: TID251
  298. maxsize=maxsize,
  299. )
  300. return cast(Any, wrapper) # type: ignore[no-any-return]
  301. def json_safe(data: object) -> object:
  302. """Translates a mapping / sequence recursively in the same fashion
  303. as `pydantic` v2's `model_dump(mode="json")`.
  304. """
  305. if is_mapping(data):
  306. return {json_safe(key): json_safe(value) for key, value in data.items()}
  307. if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
  308. return [json_safe(item) for item in data]
  309. if isinstance(data, (datetime, date)):
  310. return data.isoformat()
  311. return data
  312. def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
  313. from ..lib.azure import AzureOpenAI
  314. return isinstance(client, AzureOpenAI)
  315. def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
  316. from ..lib.azure import AsyncAzureOpenAI
  317. return isinstance(client, AsyncAzureOpenAI)