_transform.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. from __future__ import annotations
  2. import io
  3. import base64
  4. import pathlib
  5. from typing import Any, Mapping, TypeVar, cast
  6. from datetime import date, datetime
  7. from typing_extensions import Literal, get_args, override, get_type_hints as _get_type_hints
  8. import anyio
  9. import pydantic
  10. from ._utils import (
  11. is_list,
  12. is_given,
  13. lru_cache,
  14. is_mapping,
  15. is_iterable,
  16. is_sequence,
  17. )
  18. from .._files import is_base64_file_input
  19. from ._compat import get_origin, is_typeddict
  20. from ._typing import (
  21. is_list_type,
  22. is_union_type,
  23. extract_type_arg,
  24. is_iterable_type,
  25. is_required_type,
  26. is_sequence_type,
  27. is_annotated_type,
  28. strip_annotated_type,
  29. )
  30. _T = TypeVar("_T")
  31. # TODO: support for drilling globals() and locals()
  32. # TODO: ensure works correctly with forward references in all cases
  33. PropertyFormat = Literal["iso8601", "base64", "custom"]
  34. class PropertyInfo:
  35. """Metadata class to be used in Annotated types to provide information about a given type.
  36. For example:
  37. class MyParams(TypedDict):
  38. account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]
  39. This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
  40. """
  41. alias: str | None
  42. format: PropertyFormat | None
  43. format_template: str | None
  44. discriminator: str | None
  45. def __init__(
  46. self,
  47. *,
  48. alias: str | None = None,
  49. format: PropertyFormat | None = None,
  50. format_template: str | None = None,
  51. discriminator: str | None = None,
  52. ) -> None:
  53. self.alias = alias
  54. self.format = format
  55. self.format_template = format_template
  56. self.discriminator = discriminator
  57. @override
  58. def __repr__(self) -> str:
  59. return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"
  60. def maybe_transform(
  61. data: object,
  62. expected_type: object,
  63. ) -> Any | None:
  64. """Wrapper over `transform()` that allows `None` to be passed.
  65. See `transform()` for more details.
  66. """
  67. if data is None:
  68. return None
  69. return transform(data, expected_type)
  70. # Wrapper over _transform_recursive providing fake types
  71. def transform(
  72. data: _T,
  73. expected_type: object,
  74. ) -> _T:
  75. """Transform dictionaries based off of type information from the given type, for example:
  76. ```py
  77. class Params(TypedDict, total=False):
  78. card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
  79. transformed = transform({"card_id": "<my card ID>"}, Params)
  80. # {'cardID': '<my card ID>'}
  81. ```
  82. Any keys / data that does not have type information given will be included as is.
  83. It should be noted that the transformations that this function does are not represented in the type system.
  84. """
  85. transformed = _transform_recursive(data, annotation=cast(type, expected_type))
  86. return cast(_T, transformed)
  87. @lru_cache(maxsize=8096)
  88. def _get_annotated_type(type_: type) -> type | None:
  89. """If the given type is an `Annotated` type then it is returned, if not `None` is returned.
  90. This also unwraps the type when applicable, e.g. `Required[Annotated[T, ...]]`
  91. """
  92. if is_required_type(type_):
  93. # Unwrap `Required[Annotated[T, ...]]` to `Annotated[T, ...]`
  94. type_ = get_args(type_)[0]
  95. if is_annotated_type(type_):
  96. return type_
  97. return None
  98. def _maybe_transform_key(key: str, type_: type) -> str:
  99. """Transform the given `data` based on the annotations provided in `type_`.
  100. Note: this function only looks at `Annotated` types that contain `PropertyInfo` metadata.
  101. """
  102. annotated_type = _get_annotated_type(type_)
  103. if annotated_type is None:
  104. # no `Annotated` definition for this type, no transformation needed
  105. return key
  106. # ignore the first argument as it is the actual type
  107. annotations = get_args(annotated_type)[1:]
  108. for annotation in annotations:
  109. if isinstance(annotation, PropertyInfo) and annotation.alias is not None:
  110. return annotation.alias
  111. return key
  112. def _no_transform_needed(annotation: type) -> bool:
  113. return annotation == float or annotation == int
  114. def _transform_recursive(
  115. data: object,
  116. *,
  117. annotation: type,
  118. inner_type: type | None = None,
  119. ) -> object:
  120. """Transform the given data against the expected type.
  121. Args:
  122. annotation: The direct type annotation given to the particular piece of data.
  123. This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
  124. inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
  125. is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
  126. the list can be transformed using the metadata from the container type.
  127. Defaults to the same value as the `annotation` argument.
  128. """
  129. from .._compat import model_dump
  130. if inner_type is None:
  131. inner_type = annotation
  132. stripped_type = strip_annotated_type(inner_type)
  133. origin = get_origin(stripped_type) or stripped_type
  134. if is_typeddict(stripped_type) and is_mapping(data):
  135. return _transform_typeddict(data, stripped_type)
  136. if origin == dict and is_mapping(data):
  137. items_type = get_args(stripped_type)[1]
  138. return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
  139. if (
  140. # List[T]
  141. (is_list_type(stripped_type) and is_list(data))
  142. # Iterable[T]
  143. or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
  144. # Sequence[T]
  145. or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
  146. ):
  147. # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
  148. # intended as an iterable, so we don't transform it.
  149. if isinstance(data, dict):
  150. return cast(object, data)
  151. inner_type = extract_type_arg(stripped_type, 0)
  152. if _no_transform_needed(inner_type):
  153. # for some types there is no need to transform anything, so we can get a small
  154. # perf boost from skipping that work.
  155. #
  156. # but we still need to convert to a list to ensure the data is json-serializable
  157. if is_list(data):
  158. return data
  159. return list(data)
  160. return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
  161. if is_union_type(stripped_type):
  162. # For union types we run the transformation against all subtypes to ensure that everything is transformed.
  163. #
  164. # TODO: there may be edge cases where the same normalized field name will transform to two different names
  165. # in different subtypes.
  166. for subtype in get_args(stripped_type):
  167. data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
  168. return data
  169. if isinstance(data, pydantic.BaseModel):
  170. return model_dump(data, exclude_unset=True, mode="json", exclude=getattr(data, "__api_exclude__", None))
  171. annotated_type = _get_annotated_type(annotation)
  172. if annotated_type is None:
  173. return data
  174. # ignore the first argument as it is the actual type
  175. annotations = get_args(annotated_type)[1:]
  176. for annotation in annotations:
  177. if isinstance(annotation, PropertyInfo) and annotation.format is not None:
  178. return _format_data(data, annotation.format, annotation.format_template)
  179. return data
  180. def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
  181. if isinstance(data, (date, datetime)):
  182. if format_ == "iso8601":
  183. return data.isoformat()
  184. if format_ == "custom" and format_template is not None:
  185. return data.strftime(format_template)
  186. if format_ == "base64" and is_base64_file_input(data):
  187. binary: str | bytes | None = None
  188. if isinstance(data, pathlib.Path):
  189. binary = data.read_bytes()
  190. elif isinstance(data, io.IOBase):
  191. binary = data.read()
  192. if isinstance(binary, str): # type: ignore[unreachable]
  193. binary = binary.encode()
  194. if not isinstance(binary, bytes):
  195. raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
  196. return base64.b64encode(binary).decode("ascii")
  197. return data
  198. def _transform_typeddict(
  199. data: Mapping[str, object],
  200. expected_type: type,
  201. ) -> Mapping[str, object]:
  202. result: dict[str, object] = {}
  203. annotations = get_type_hints(expected_type, include_extras=True)
  204. for key, value in data.items():
  205. if not is_given(value):
  206. # we don't need to include omitted values here as they'll
  207. # be stripped out before the request is sent anyway
  208. continue
  209. type_ = annotations.get(key)
  210. if type_ is None:
  211. # we do not have a type annotation for this field, leave it as is
  212. result[key] = value
  213. else:
  214. result[_maybe_transform_key(key, type_)] = _transform_recursive(value, annotation=type_)
  215. return result
  216. async def async_maybe_transform(
  217. data: object,
  218. expected_type: object,
  219. ) -> Any | None:
  220. """Wrapper over `async_transform()` that allows `None` to be passed.
  221. See `async_transform()` for more details.
  222. """
  223. if data is None:
  224. return None
  225. return await async_transform(data, expected_type)
  226. async def async_transform(
  227. data: _T,
  228. expected_type: object,
  229. ) -> _T:
  230. """Transform dictionaries based off of type information from the given type, for example:
  231. ```py
  232. class Params(TypedDict, total=False):
  233. card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]
  234. transformed = transform({"card_id": "<my card ID>"}, Params)
  235. # {'cardID': '<my card ID>'}
  236. ```
  237. Any keys / data that does not have type information given will be included as is.
  238. It should be noted that the transformations that this function does are not represented in the type system.
  239. """
  240. transformed = await _async_transform_recursive(data, annotation=cast(type, expected_type))
  241. return cast(_T, transformed)
  242. async def _async_transform_recursive(
  243. data: object,
  244. *,
  245. annotation: type,
  246. inner_type: type | None = None,
  247. ) -> object:
  248. """Transform the given data against the expected type.
  249. Args:
  250. annotation: The direct type annotation given to the particular piece of data.
  251. This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc
  252. inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
  253. is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
  254. the list can be transformed using the metadata from the container type.
  255. Defaults to the same value as the `annotation` argument.
  256. """
  257. from .._compat import model_dump
  258. if inner_type is None:
  259. inner_type = annotation
  260. stripped_type = strip_annotated_type(inner_type)
  261. origin = get_origin(stripped_type) or stripped_type
  262. if is_typeddict(stripped_type) and is_mapping(data):
  263. return await _async_transform_typeddict(data, stripped_type)
  264. if origin == dict and is_mapping(data):
  265. items_type = get_args(stripped_type)[1]
  266. return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
  267. if (
  268. # List[T]
  269. (is_list_type(stripped_type) and is_list(data))
  270. # Iterable[T]
  271. or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
  272. # Sequence[T]
  273. or (is_sequence_type(stripped_type) and is_sequence(data) and not isinstance(data, str))
  274. ):
  275. # dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
  276. # intended as an iterable, so we don't transform it.
  277. if isinstance(data, dict):
  278. return cast(object, data)
  279. inner_type = extract_type_arg(stripped_type, 0)
  280. if _no_transform_needed(inner_type):
  281. # for some types there is no need to transform anything, so we can get a small
  282. # perf boost from skipping that work.
  283. #
  284. # but we still need to convert to a list to ensure the data is json-serializable
  285. if is_list(data):
  286. return data
  287. return list(data)
  288. return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
  289. if is_union_type(stripped_type):
  290. # For union types we run the transformation against all subtypes to ensure that everything is transformed.
  291. #
  292. # TODO: there may be edge cases where the same normalized field name will transform to two different names
  293. # in different subtypes.
  294. for subtype in get_args(stripped_type):
  295. data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
  296. return data
  297. if isinstance(data, pydantic.BaseModel):
  298. return model_dump(data, exclude_unset=True, mode="json")
  299. annotated_type = _get_annotated_type(annotation)
  300. if annotated_type is None:
  301. return data
  302. # ignore the first argument as it is the actual type
  303. annotations = get_args(annotated_type)[1:]
  304. for annotation in annotations:
  305. if isinstance(annotation, PropertyInfo) and annotation.format is not None:
  306. return await _async_format_data(data, annotation.format, annotation.format_template)
  307. return data
  308. async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
  309. if isinstance(data, (date, datetime)):
  310. if format_ == "iso8601":
  311. return data.isoformat()
  312. if format_ == "custom" and format_template is not None:
  313. return data.strftime(format_template)
  314. if format_ == "base64" and is_base64_file_input(data):
  315. binary: str | bytes | None = None
  316. if isinstance(data, pathlib.Path):
  317. binary = await anyio.Path(data).read_bytes()
  318. elif isinstance(data, io.IOBase):
  319. binary = data.read()
  320. if isinstance(binary, str): # type: ignore[unreachable]
  321. binary = binary.encode()
  322. if not isinstance(binary, bytes):
  323. raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
  324. return base64.b64encode(binary).decode("ascii")
  325. return data
  326. async def _async_transform_typeddict(
  327. data: Mapping[str, object],
  328. expected_type: type,
  329. ) -> Mapping[str, object]:
  330. result: dict[str, object] = {}
  331. annotations = get_type_hints(expected_type, include_extras=True)
  332. for key, value in data.items():
  333. if not is_given(value):
  334. # we don't need to include omitted values here as they'll
  335. # be stripped out before the request is sent anyway
  336. continue
  337. type_ = annotations.get(key)
  338. if type_ is None:
  339. # we do not have a type annotation for this field, leave it as is
  340. result[key] = value
  341. else:
  342. result[_maybe_transform_key(key, type_)] = await _async_transform_recursive(value, annotation=type_)
  343. return result
  344. @lru_cache(maxsize=8096)
  345. def get_type_hints(
  346. obj: Any,
  347. globalns: dict[str, Any] | None = None,
  348. localns: Mapping[str, Any] | None = None,
  349. include_extras: bool = False,
  350. ) -> dict[str, Any]:
  351. return _get_type_hints(obj, globalns=globalns, localns=localns, include_extras=include_extras)