| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457 |
- from __future__ import annotations
- import io
- import base64
- import pathlib
- from typing import Any, Mapping, TypeVar, cast
- from datetime import date, datetime
- from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
- import anyio
- import pydantic
- from ._utils import (
- is_list,
- is_given,
- lru_cache,
- is_mapping,
- is_iterable,
- is_sequence,
- )
- from .._files import is_base64_file_input
- from ._compat import get_origin, is_typeddict
- from ._typing import (
- is_list_type,
- is_union_type,
- extract_type_arg,
- is_iterable_type,
- is_required_type,
- is_sequence_type,
- is_annotated_type,
- strip_annotated_type,
- )
# Generic placeholder: `transform()` / `async_transform()` declare that they return the same type they receive.
_T = TypeVar("_T")


# TODO: support for drilling globals() and locals()
# TODO: ensure works correctly with forward references in all cases


# The set of values accepted for `PropertyInfo.format`.
PropertyFormat = Literal["iso8601", "base64", "custom"]
class PropertyInfo:
    """Per-field metadata intended to be placed inside an `Annotated[...]` type.

    Example:

        class MyParams(TypedDict):
            account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]

    With that declaration `{'account_holder_name': 'Robert'}` is rewritten to
    `{'accountHolderName': 'Robert'}` before being sent to the API.

    Every attribute is optional and defaults to `None`; all constructor arguments are keyword-only.
    """

    # Replacement key to use for the annotated field.
    alias: str | None
    # Which value formatting to apply (one of the `PropertyFormat` literals).
    format: PropertyFormat | None
    # `strftime` template, consulted when `format` is "custom".
    format_template: str | None
    # Stored as given; not read by anything visible in this module.
    discriminator: str | None

    def __init__(
        self,
        *,
        alias: str | None = None,
        format: PropertyFormat | None = None,
        format_template: str | None = None,
        discriminator: str | None = None,
    ) -> None:
        self.alias, self.format = alias, format
        self.format_template, self.discriminator = format_template, discriminator

    @override
    def __repr__(self) -> str:
        # Note: string-ish fields are wrapped in single quotes even when they are `None`.
        rendered = f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
        return rendered
def maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """`None`-tolerant variant of `transform()`.

    Returns `None` when `data` is `None`; every other value is handed to
    `transform()` unchanged together with `expected_type`.

    See `transform()` for more details.
    """
    return None if data is None else transform(data, expected_type)
# Typed facade over `_transform_recursive`: the declared return type mirrors the input type
# even though the transformed value may differ from it.
def transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Rewrite `data` according to the type information carried by `expected_type`, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Keys / data without type information are passed through as is.

    Note that the changes made by this function are not visible to the type system:
    the result is simply cast back to the type of `data`.
    """
    return cast(_T, _transform_recursive(data, annotation=cast(type, expected_type)))
@lru_cache(maxsize=8096)
def _get_annotated_type(type_: type) -> type | None:
    """Return the `Annotated[...]` form behind `type_`, or `None` when there is none.

    A `Required[...]` wrapper is peeled off first, so `Required[Annotated[T, ...]]`
    yields `Annotated[T, ...]`.
    """
    # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]` before checking
    candidate = get_args(type_)[0] if is_required_type(type_) else type_
    return candidate if is_annotated_type(candidate) else None
def _maybe_transform_key(key: str, type_: type) -> str:
    """Return the key to send for `key`, based on the annotations found on `type_`.

    The alias of the first `PropertyInfo` metadata entry that defines one is used;
    if `type_` is not an `Annotated` type, or no entry carries an alias, `key` is
    returned unchanged.
    """
    annotated_type = _get_annotated_type(type_)
    if annotated_type is None:
        # nothing attached to this type, so there is nothing to rename
        return key

    # the first `Annotated` argument is the real type; metadata starts at index 1
    aliases = (
        meta.alias
        for meta in get_args(annotated_type)[1:]
        if isinstance(meta, PropertyInfo) and meta.alias is not None
    )
    return next(aliases, key)
- def _no_transform_needed(annotation: type) -> bool:
- return annotation == float or annotation == int
def _transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform `data` against the type it is expected to have.

    Args:
        annotation: The annotation attached directly to this piece of data. It may still be
            wrapped in metadata types such as `Required[T]` or `Annotated[T, ...]`.

        inner_type: The "inside" type when the outer type is a container such as `List[T]`.
            It is set to `T` for each entry so the entry is matched against `T` while the
            metadata of the container annotation stays available.

            Falls back to `annotation` when not given.
    """
    from .._compat import model_dump

    target_type = annotation if inner_type is None else inner_type
    bare_type = strip_annotated_type(target_type)
    origin = get_origin(bare_type) or bare_type

    # TypedDict + mapping: per-key handling lives in `_transform_typeddict`
    if is_typeddict(bare_type) and is_mapping(data):
        return _transform_typeddict(data, bare_type)

    # Dict[K, V] + mapping: keys are kept, each value is transformed against V
    if origin == dict and is_mapping(data):
        value_type = get_args(bare_type)[1]
        return {name: _transform_recursive(item, annotation=value_type) for name, item in data.items()}

    if (
        # List[T]
        (is_list_type(bare_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(bare_type) and is_iterable(data) and not isinstance(data, str))
        # Sequence[T]
        or (is_sequence_type(bare_type) and is_sequence(data) and not isinstance(data, str))
    ):
        # iterating a dict only yields its keys, which is rarely what was meant,
        # so dicts are returned untouched here.
        if isinstance(data, dict):
            return cast(object, data)

        element_type = extract_type_arg(bare_type, 0)
        if not _no_transform_needed(element_type):
            return [_transform_recursive(entry, annotation=annotation, inner_type=element_type) for entry in data]

        # entries need no work for these element types, which saves a little time;
        # non-list iterables are still materialised so the result is json-serializable.
        return data if is_list(data) else list(data)

    if is_union_type(bare_type):
        # Every member of the union gets a chance to transform the value, each one
        # receiving the output of the previous one.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        result = data
        for member in get_args(bare_type):
            result = _transform_recursive(result, annotation=annotation, inner_type=member)
        return result

    if isinstance(data, pydantic.BaseModel):
        return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # skip the first `Annotated` argument (the real type); the first `PropertyInfo`
    # that declares a format decides how the value is rendered.
    for meta in get_args(annotated_type)[1:]:
        if isinstance(meta, PropertyInfo) and meta.format is not None:
            return _format_data(data, meta.format, meta.format_template)

    return data
- def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
- if isinstance(data, (date, datetime)):
- if format_ == "iso8601":
- return data.isoformat()
- if format_ == "custom" and format_template is not None:
- return data.strftime(format_template)
- if format_ == "base64" and is_base64_file_input(data):
- binary: str | bytes | None = None
- if isinstance(data, pathlib.Path):
- binary = data.read_bytes()
- elif isinstance(data, io.IOBase):
- binary = data.read()
- if isinstance(binary, str): # type: ignore[unreachable]
- binary = binary.encode()
- if not isinstance(binary, bytes):
- raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
- return base64.b64encode(binary).decode("ascii")
- return data
def _transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Build a new dict from `data` using the field annotations of `expected_type`.

    Values for which `is_given` is false are dropped, keys without an annotation are
    copied verbatim, and annotated keys are renamed via `_maybe_transform_key` with
    their values passed through `_transform_recursive`.
    """
    hints = get_type_hints(expected_type, include_extras=True)
    transformed: dict[str, object] = {}

    for name, item in data.items():
        if is_given(item):
            hint = hints.get(name)
            if hint is None:
                # no annotation for this field: keep it exactly as provided
                transformed[name] = item
            else:
                transformed[_maybe_transform_key(name, hint)] = _transform_recursive(item, annotation=hint)
        # omitted values are skipped; they would be stripped out before the
        # request is sent anyway

    return transformed
async def async_maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """`None`-tolerant variant of `async_transform()`.

    Returns `None` when `data` is `None`; every other value is awaited through
    `async_transform()` together with `expected_type`.

    See `async_transform()` for more details.
    """
    return None if data is None else await async_transform(data, expected_type)
async def async_transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Rewrite `data` according to the type information carried by `expected_type`, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = await async_transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Keys / data without type information are passed through as is.

    Note that the changes made by this function are not visible to the type system:
    the result is simply cast back to the type of `data`.
    """
    return cast(_T, await _async_transform_recursive(data, annotation=cast(type, expected_type)))
async def _async_transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type (async counterpart of `_transform_recursive`).

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.

    Returns:
        The transformed value; data that matches no rule is returned unchanged.
    """
    from .._compat import model_dump

    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    origin = get_origin(stripped_type) or stripped_type
    if is_typeddict(stripped_type) and is_mapping(data):
        return await _async_transform_typeddict(data, stripped_type)

    if origin == dict and is_mapping(data):
        items_type = get_args(stripped_type)[1]
        # Stay on the async path for nested values, so that formatting further down
        # goes through `_async_format_data` rather than the blocking sync variant.
        return {key: await _async_transform_recursive(value, annotation=items_type) for key, value in data.items()}

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
        # Sequence[T]
        or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
    ):
        # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
        # intended as an iterable, so we don't transform it.
        if isinstance(data, dict):
            return cast(object, data)

        inner_type = extract_type_arg(stripped_type, 0)
        if _no_transform_needed(inner_type):
            # for some types there is no need to transform anything, so we can get a small
            # perf boost from skipping that work.
            #
            # but we still need to convert to a list to ensure the data is json-serializable
            if is_list(data):
                return data
            return list(data)

        return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        # Honour `__api_exclude__` exactly like the sync implementation does, so that
        # sync and async clients serialise models identically.
        return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type
    for metadata in get_args(annotated_type)[1:]:
        if isinstance(metadata, PropertyInfo) and metadata.format is not None:
            return await _async_format_data(data, metadata.format, metadata.format_template)

    return data
- async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
- if isinstance(data, (date, datetime)):
- if format_ == "iso8601":
- return data.isoformat()
- if format_ == "custom" and format_template is not None:
- return data.strftime(format_template)
- if format_ == "base64" and is_base64_file_input(data):
- binary: str | bytes | None = None
- if isinstance(data, pathlib.Path):
- binary = await anyio.Path(data).read_bytes()
- elif isinstance(data, io.IOBase):
- binary = data.read()
- if isinstance(binary, str): # type: ignore[unreachable]
- binary = binary.encode()
- if not isinstance(binary, bytes):
- raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
- return base64.b64encode(binary).decode("ascii")
- return data
async def _async_transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Build a new dict from `data` using the field annotations of `expected_type`.

    Values for which `is_given` is false are dropped, keys without an annotation are
    copied verbatim, and annotated keys are renamed via `_maybe_transform_key` with
    their values awaited through `_async_transform_recursive`.
    """
    hints = get_type_hints(expected_type, include_extras=True)
    transformed: dict[str, object] = {}

    for name, item in data.items():
        if is_given(item):
            hint = hints.get(name)
            if hint is None:
                # no annotation for this field: keep it exactly as provided
                transformed[name] = item
            else:
                transformed[_maybe_transform_key(name, hint)] = await _async_transform_recursive(item, annotation=hint)
        # omitted values are skipped; they would be stripped out before the
        # request is sent anyway

    return transformed
@lru_cache(maxsize=8096)
def get_type_hints(
    obj: Any,
    globalns: dict[str, Any] | None = None,
    localns: Mapping[str, Any] | None = None,
    include_extras: bool = False,
) -> dict[str, Any]:
    """Memoised pass-through to `_get_type_hints` with the same four arguments.

    NOTE(review): presumably `lru_cache` keys on the arguments like `functools.lru_cache`,
    in which case passing an actual `dict` as `globalns` would not be cacheable - confirm.
    """
    resolved = _get_type_hints(
        obj,
        globalns=globalns,
        localns=localns,
        include_extras=include_extras,
    )
    return resolved
|