| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- from __future__ import annotations
- import os
- import re
- import inspect
- import functools
- from typing import (
- TYPE_CHECKING,
- Any,
- Tuple,
- Mapping,
- TypeVar,
- Callable,
- Iterable,
- Sequence,
- cast,
- overload,
- )
- from pathlib import Path
- from datetime import date, datetime
- from typing_extensions import TypeGuard
- import sniffio
- from .._types import Omit, NotGiven, FileTypes, HeadersLike
# Unbounded generic type variable shared by several helpers below
# (e.g. `flatten`, `is_given`, `deepcopy_minimal`).
_T = TypeVar("_T")
# Bounded type variables used by the `is_*_t` helpers (`is_tuple_t`, `is_mapping_t`,
# `is_sequence_t`) to narrow a known union to its tuple / mapping / sequence member
# while keeping the caller's precise type.
_TupleT = TypeVar("_TupleT", bound=Tuple[object, ...])
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
# Bound to any callable so that decorator factories (`required_args`, `lru_cache`)
# can declare that they hand back the same callable type they were given.
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
- if TYPE_CHECKING:
- from ..lib.azure import AzureOpenAI, AsyncAzureOpenAI
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
    """Concatenate an iterable of iterables into one flat list.

    Only a single level of nesting is removed; the order of elements is preserved.
    """
    flat: list[_T] = []
    for inner in t:
        flat.extend(inner)
    return flat
def extract_files(
    # TODO: this needs to take Dict but variance issues.....
    # create protocol type ?
    query: Mapping[str, object],
    *,
    paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
    """Recursively extract files from the given dictionary based on specified paths.

    A path may look like this ['foo', 'files', '<array>', 'data'].

    Paths are processed in order and the resulting ``(key, file)`` pairs are
    concatenated in that same order.

    Note: this mutates the given dictionary (the field at the end of each path is
    removed from its containing dict).
    """
    return [
        pair
        for path in paths
        for pair in _extract_items(query, path, index=0, flattened_key=None)
    ]
def _extract_items(
    obj: object,
    path: Sequence[str],
    *,
    index: int,
    flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
    """Follow ``path`` from ``path[index]`` onwards through ``obj`` and collect file entries.

    Args:
        obj: The value reached so far (the top-level query on the first call made by
            `extract_files`).
        path: The key sequence to follow. The literal segment ``"<array>"`` means
            "descend into every element of a list".
        index: Position in ``path`` of the segment to apply to ``obj``.
        flattened_key: The form-style key accumulated so far (dict keys become
            ``"a[b]"``, list levels append ``"[]"``), or ``None`` before any segment
            has been consumed.

    Returns:
        A list of ``(flattened_key, file)`` pairs. A missing dict key, a non-matching
        segment, an unexpected value type, or a value that is not given all yield an
        empty list instead of an error.

    Side effects:
        When the final path segment is a dict key, that key is *popped* from its dict,
        so the caller's structure is mutated.
    """
    try:
        key = path[index]
    except IndexError:
        # The path is exhausted: ``obj`` is the value the path points at.
        if not is_given(obj):
            # no value was provided - we can safely ignore
            return []

        # cyclical import
        from .._files import assert_is_file_content

        # We have exhausted the path, return the entry we found.
        assert flattened_key is not None

        if is_list(obj):
            # A list at the end of the path is emitted as several entries that all share
            # the "<key>[]" name.
            files: list[tuple[str, FileTypes]] = []
            for entry in obj:
                # NOTE(review): this parses as `(flattened_key + "[]") if flattened_key else ""`,
                # so an empty flattened_key is passed through as an empty key.
                assert_is_file_content(entry, key=flattened_key + "[]" if flattened_key else "")
                files.append((flattened_key + "[]", cast(FileTypes, entry)))
            return files

        # presumably raises when obj is not valid file content - TODO confirm in .._files
        assert_is_file_content(obj, key=flattened_key)
        return [(flattened_key, cast(FileTypes, obj))]

    index += 1
    if is_dict(obj):
        try:
            # We are at the last entry in the path so we must remove the field
            if (len(path)) == index:
                item = obj.pop(key)
            else:
                item = obj[key]
        except KeyError:
            # Key was not present in the dictionary, this is not indicative of an error
            # as the given path may not point to a required field. We also do not want
            # to enforce required fields as the API may differ from the spec in some cases.
            return []
        # Extend the form-style key: "key" at the top level, "parent[key]" below it.
        if flattened_key is None:
            flattened_key = key
        else:
            flattened_key += f"[{key}]"
        return _extract_items(
            item,
            path,
            index=index,
            flattened_key=flattened_key,
        )
    elif is_list(obj):
        # Only the "<array>" segment may descend into a list; any other segment does not match.
        if key != "<array>":
            return []

        return flatten(
            [
                _extract_items(
                    item,
                    path,
                    index=index,
                    flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
                )
                for item in obj
            ]
        )

    # Something unexpected was passed, just ignore it.
    return []
def is_given(obj: _T | NotGiven | Omit) -> TypeGuard[_T]:
    """Return False for the `NotGiven` and `Omit` markers and True for any other value.

    Typed as a TypeGuard so callers can narrow away both marker types in one check.
    """
    if isinstance(obj, (NotGiven, Omit)):
        return False
    return True
# Type-safe helpers for narrowing types with TypeVars.
# The default narrowing for `isinstance(obj, dict)` is `dict[Unknown, Unknown]`,
# which causes Pyright to (rightfully) report errors. Since we don't care about
# the contained types, we can safely use `object` in their place.
#
# Two kinds of helpers are defined, `is_*` and `is_*_t`, for different use cases:
# - `is_*` is for when you're dealing with an unknown input
# - `is_*_t` is for when you're narrowing a known union type to a specific subset
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
    """Narrow an unknown value to ``tuple[object, ...]`` if it is a tuple instance."""
    return isinstance(obj, (tuple,))
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
    """Narrow a known union to its tuple member(s), keeping the precise tuple type."""
    return isinstance(obj, (tuple,))
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
    """Narrow an unknown value to ``Sequence[object]`` (lists, tuples, strings, ranges, ...)."""
    return isinstance(obj, (Sequence,))
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
    """Narrow a known union to its sequence member(s), keeping the precise sequence type."""
    return isinstance(obj, (Sequence,))
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
    """Narrow an unknown value to ``Mapping[str, object]``.

    Only the container type is checked; keys are not verified to be strings.
    """
    return isinstance(obj, (Mapping,))
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
    """Narrow a known union to its mapping member(s), keeping the precise mapping type."""
    return isinstance(obj, (Mapping,))
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
    """Narrow an unknown value to ``dict[object, object]`` if it is a dict (or subclass)."""
    return isinstance(obj, (dict,))
def is_list(obj: object) -> TypeGuard[list[object]]:
    """Narrow an unknown value to ``list[object]`` if it is a list (or subclass)."""
    return isinstance(obj, (list,))
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
    """Narrow an unknown value to ``Iterable[object]`` (anything recognised as iterable)."""
    return isinstance(obj, (Iterable,))
def deepcopy_minimal(item: _T) -> _T:
    """Cheaply copy ``item``, recursing only into mappings and lists.

    - any mapping (e.g. `dict`) is rebuilt as a new plain ``dict``
    - a ``list`` is rebuilt as a new ``list``
    - every other value is returned as-is, i.e. shared with the input

    This is a minimal stand-in for ``copy.deepcopy()``, used for performance reasons.
    """
    if is_mapping(item):
        copied_mapping = {key: deepcopy_minimal(value) for key, value in item.items()}
        return cast(_T, copied_mapping)

    if not is_list(item):
        return item

    copied_list = [deepcopy_minimal(element) for element in item]
    return cast(_T, copied_list)
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
    """Join strings as an English-style list, e.g. ``"a, b or c"``.

    ``delim`` separates every item except the last two, and ``final`` is the word
    placed before the last item. An empty sequence yields ``""``.
    """
    count = len(seq)

    if count < 3:
        if count == 0:
            return ""
        if count == 1:
            return seq[0]
        return f"{seq[0]} {final} {seq[1]}"

    leading = delim.join(seq[:-1])
    return leading + f" {final} {seq[-1]}"
def quote(string: str) -> str:
    """Wrap ``string`` in single quotation marks. Does *not* do any escaping."""
    return "'{}'".format(string)
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
    """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.

    Useful for enforcing runtime validation of overloaded functions.

    Each variant is a sequence of parameter names. A call is accepted when every name
    in at least one variant was supplied, either positionally or by keyword. If no
    variants are given at all, every call is accepted.

    Example usage:
    ```py
    @overload
    def foo(*, a: str) -> str: ...


    @overload
    def foo(*, b: bool) -> str: ...


    # This enforces the same constraints that a static type checker would
    # i.e. that either a or b must be passed to the function
    @required_args(["a"], ["b"])
    def foo(*, a: str | None = None, b: bool | None = None) -> str: ...
    ```

    Raises (when the decorated function is called):
        TypeError: if more positional arguments are passed than the function declares
            positional parameters, or if no variant is fully satisfied.
    """

    def inner(func: CallableT) -> CallableT:
        params = inspect.signature(func).parameters
        # Names that can be bound positionally, in declaration order, so the i-th
        # positional argument can be mapped back to its parameter name.
        positional = [
            name
            for name, param in params.items()
            if param.kind in {param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD}
        ]

        @functools.wraps(func)
        def wrapper(*args: object, **kwargs: object) -> object:
            if len(args) > len(positional):
                raise TypeError(
                    f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
                )

            given_params: set[str] = set(positional[: len(args)])
            given_params.update(kwargs)

            # With no variants there is nothing to require; otherwise at least one
            # variant must be fully present.
            if not variants or any(all(param in given_params for param in variant) for variant in variants):
                return func(*args, **kwargs)

            if len(variants) > 1:
                variations = human_join(
                    ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
                )
                msg = f"Missing required arguments; Expected either {variations} arguments to be given"
            else:
                # Keep the variant's declared order (dropping duplicates) instead of going
                # through a `set`, so the error message is identical on every run.
                missing = [arg for arg in dict.fromkeys(variants[0]) if arg not in given_params]
                if len(missing) > 1:
                    msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
                else:
                    msg = f"Missing required argument: {quote(missing[0])}"
            raise TypeError(msg)

        return wrapper  # type: ignore

    return inner
_K = TypeVar("_K")
_V = TypeVar("_V")


@overload
def strip_not_given(obj: None) -> None: ...


@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...


@overload
def strip_not_given(obj: object) -> object: ...


def strip_not_given(obj: object | None) -> object:
    """Drop the top-level entries of a mapping whose value is a `NotGiven` instance.

    ``None`` and any non-mapping value are returned unchanged. For a mapping, a new
    ``dict`` is returned; nested values are not inspected.
    """
    if obj is None or not is_mapping(obj):
        return obj

    kept: dict[object, object] = {}
    for key, value in obj.items():
        if not isinstance(value, NotGiven):
            kept[key] = value
    return kept
def coerce_integer(val: str) -> int:
    """Parse ``val`` as a base-10 integer (surrounding whitespace is allowed by ``int``)."""
    return int(val, 10)
def coerce_float(val: str) -> float:
    """Parse ``val`` with the builtin ``float`` constructor."""
    parsed = float(val)
    return parsed
def coerce_boolean(val: str) -> bool:
    """Return True only for the exact (case-sensitive) strings "true", "1" and "on"."""
    return val in ("true", "1", "on")
def maybe_coerce_integer(val: str | None) -> int | None:
    """Like `coerce_integer`, but passes ``None`` through unchanged."""
    return None if val is None else coerce_integer(val)
def maybe_coerce_float(val: str | None) -> float | None:
    """Like `coerce_float`, but passes ``None`` through unchanged."""
    return None if val is None else coerce_float(val)
def maybe_coerce_boolean(val: str | None) -> bool | None:
    """Like `coerce_boolean`, but passes ``None`` through unchanged."""
    return None if val is None else coerce_boolean(val)
def removeprefix(string: str, prefix: str) -> str:
    """Remove a prefix from a string.

    Backport of `str.removeprefix` for Python < 3.9: the prefix is stripped once if
    present, otherwise ``string`` is returned unchanged.
    """
    if not string.startswith(prefix):
        return string
    return string[len(prefix) :]
def removesuffix(string: str, suffix: str) -> str:
    """Remove a suffix from a string.

    Backport of `str.removesuffix` for Python < 3.9: if ``string`` ends with a
    non-empty ``suffix`` it is stripped once, otherwise ``string`` is returned unchanged.
    """
    # An empty suffix must be a no-op (as with `str.removesuffix`). Without the
    # `suffix` guard, `string[:-0]` is `string[:0]`, i.e. the empty string.
    if suffix and string.endswith(suffix):
        return string[: -len(suffix)]
    return string
def file_from_path(path: str) -> FileTypes:
    """Read the whole file at ``path`` and return it as a ``(basename, bytes)`` tuple."""
    data = Path(path).read_bytes()
    return os.path.basename(path), data
def get_required_header(headers: HeadersLike, header: str) -> str:
    """Look up ``header`` in ``headers``, tolerating differences in capitalisation.

    For mapping inputs, a case-insensitive scan first returns the first string value
    whose key matches. After that, the exact name, its lower-case and upper-case forms,
    and an "intercaps" form (e.g. ``Stainless-Event-Id``) are tried via ``.get`` and
    the first truthy value is returned. ``HeadersLike`` presumably also covers header
    containers that are not plain mappings but provide ``.get`` - TODO confirm.

    Raises:
        ValueError: if no variant of the header name yields a value.
    """
    wanted = header.lower()

    if is_mapping_t(headers):
        # mypy doesn't understand the type narrowing here
        for name, value in headers.items():  # type: ignore
            if name.lower() == wanted and isinstance(value, str):
                return value

    # to deal with the case where the header looks like Stainless-Event-Id
    intercaps = re.sub(r"([^\w])(\w)", lambda match: match.group(1) + match.group(2).upper(), header.capitalize())

    candidates = (header, wanted, header.upper(), intercaps)
    for candidate in candidates:
        found = headers.get(candidate)
        if found:
            return found

    raise ValueError(f"Could not find {header} header")
def get_async_library() -> str:
    """Return the async library name reported by `sniffio`, or ``"false"`` if detection raises."""
    try:
        library = sniffio.current_async_library()
    except Exception:
        library = "false"
    return library
def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
    """A version of functools.lru_cache that retains the type signature
    for the wrapped function arguments.

    The decorator is `functools.lru_cache(maxsize=...)` itself; it is only re-typed
    (via `cast`) so that type checkers see the decorated function's own signature.
    """
    decorator = functools.lru_cache(maxsize=maxsize)  # noqa: TID251
    return cast(Any, decorator)  # type: ignore[no-any-return]
def json_safe(data: object) -> object:
    """Translates a mapping / sequence recursively in the same fashion
    as `pydantic` v2's `model_dump(mode="json")`.

    - mappings become plain dicts, with keys and values converted recursively
    - other iterables (except ``str``/``bytes``/``bytearray``) become lists
    - ``date``/``datetime`` values become ISO-8601 strings via ``isoformat()``
    - everything else is returned unchanged
    """
    if is_mapping(data):
        return {json_safe(k): json_safe(v) for k, v in data.items()}

    # Strings and byte strings are iterable but must stay scalar.
    text_like = isinstance(data, (str, bytes, bytearray))
    if not text_like and is_iterable(data):
        return [json_safe(element) for element in data]

    if isinstance(data, (datetime, date)):
        return data.isoformat()

    return data
def is_azure_client(client: object) -> TypeGuard[AzureOpenAI]:
    """Return True if ``client`` is an ``AzureOpenAI`` instance.

    The class is imported at call time; the module-level import only happens under
    ``TYPE_CHECKING``, presumably to avoid an import cycle with ``..lib.azure``.
    """
    from ..lib.azure import AzureOpenAI as _AzureClient

    return isinstance(client, _AzureClient)
def is_async_azure_client(client: object) -> TypeGuard[AsyncAzureOpenAI]:
    """Return True if ``client`` is an ``AsyncAzureOpenAI`` instance.

    The class is imported at call time; the module-level import only happens under
    ``TYPE_CHECKING``, presumably to avoid an import cycle with ``..lib.azure``.
    """
    from ..lib.azure import AsyncAzureOpenAI as _AsyncAzureClient

    return isinstance(client, _AsyncAzureClient)
|