_responses.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. from __future__ import annotations
  2. import inspect
  3. from types import TracebackType
  4. from typing import Any, List, Generic, Iterable, Awaitable, cast
  5. from typing_extensions import Self, Callable, Iterator, AsyncIterator
  6. from ._types import ParsedResponseSnapshot
  7. from ._events import (
  8. ResponseStreamEvent,
  9. ResponseTextDoneEvent,
  10. ResponseCompletedEvent,
  11. ResponseTextDeltaEvent,
  12. ResponseFunctionCallArgumentsDeltaEvent,
  13. )
  14. from ...._types import Omit, omit
  15. from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
  16. from ...._models import build, construct_type_unchecked
  17. from ...._streaming import Stream, AsyncStream
  18. from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
  19. from ..._parsing._responses import TextFormatT, parse_text, parse_response
  20. from ....types.responses.tool_param import ToolParam
  21. from ....types.responses.parsed_response import (
  22. ParsedContent,
  23. ParsedResponseOutputMessage,
  24. ParsedResponseFunctionToolCall,
  25. )
  26. class ResponseStream(Generic[TextFormatT]):
  27. def __init__(
  28. self,
  29. *,
  30. raw_stream: Stream[RawResponseStreamEvent],
  31. text_format: type[TextFormatT] | Omit,
  32. input_tools: Iterable[ToolParam] | Omit,
  33. starting_after: int | None,
  34. ) -> None:
  35. self._raw_stream = raw_stream
  36. self._response = raw_stream.response
  37. self._iterator = self.__stream__()
  38. self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
  39. self._starting_after = starting_after
  40. def __next__(self) -> ResponseStreamEvent[TextFormatT]:
  41. return self._iterator.__next__()
  42. def __iter__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
  43. for item in self._iterator:
  44. yield item
  45. def __enter__(self) -> Self:
  46. return self
  47. def __stream__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
  48. for sse_event in self._raw_stream:
  49. events_to_fire = self._state.handle_event(sse_event)
  50. for event in events_to_fire:
  51. if self._starting_after is None or event.sequence_number > self._starting_after:
  52. yield event
  53. def __exit__(
  54. self,
  55. exc_type: type[BaseException] | None,
  56. exc: BaseException | None,
  57. exc_tb: TracebackType | None,
  58. ) -> None:
  59. self.close()
  60. def close(self) -> None:
  61. """
  62. Close the response and release the connection.
  63. Automatically called if the response body is read to completion.
  64. """
  65. self._response.close()
  66. def get_final_response(self) -> ParsedResponse[TextFormatT]:
  67. """Waits until the stream has been read to completion and returns
  68. the accumulated `ParsedResponse` object.
  69. """
  70. self.until_done()
  71. response = self._state._completed_response
  72. if not response:
  73. raise RuntimeError("Didn't receive a `response.completed` event.")
  74. return response
  75. def until_done(self) -> Self:
  76. """Blocks until the stream has been consumed."""
  77. consume_sync_iterator(self)
  78. return self
  79. class ResponseStreamManager(Generic[TextFormatT]):
  80. def __init__(
  81. self,
  82. api_request: Callable[[], Stream[RawResponseStreamEvent]],
  83. *,
  84. text_format: type[TextFormatT] | Omit,
  85. input_tools: Iterable[ToolParam] | Omit,
  86. starting_after: int | None,
  87. ) -> None:
  88. self.__stream: ResponseStream[TextFormatT] | None = None
  89. self.__api_request = api_request
  90. self.__text_format = text_format
  91. self.__input_tools = input_tools
  92. self.__starting_after = starting_after
  93. def __enter__(self) -> ResponseStream[TextFormatT]:
  94. raw_stream = self.__api_request()
  95. self.__stream = ResponseStream(
  96. raw_stream=raw_stream,
  97. text_format=self.__text_format,
  98. input_tools=self.__input_tools,
  99. starting_after=self.__starting_after,
  100. )
  101. return self.__stream
  102. def __exit__(
  103. self,
  104. exc_type: type[BaseException] | None,
  105. exc: BaseException | None,
  106. exc_tb: TracebackType | None,
  107. ) -> None:
  108. if self.__stream is not None:
  109. self.__stream.close()
  110. class AsyncResponseStream(Generic[TextFormatT]):
  111. def __init__(
  112. self,
  113. *,
  114. raw_stream: AsyncStream[RawResponseStreamEvent],
  115. text_format: type[TextFormatT] | Omit,
  116. input_tools: Iterable[ToolParam] | Omit,
  117. starting_after: int | None,
  118. ) -> None:
  119. self._raw_stream = raw_stream
  120. self._response = raw_stream.response
  121. self._iterator = self.__stream__()
  122. self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
  123. self._starting_after = starting_after
  124. async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
  125. return await self._iterator.__anext__()
  126. async def __aiter__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
  127. async for item in self._iterator:
  128. yield item
  129. async def __stream__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
  130. async for sse_event in self._raw_stream:
  131. events_to_fire = self._state.handle_event(sse_event)
  132. for event in events_to_fire:
  133. if self._starting_after is None or event.sequence_number > self._starting_after:
  134. yield event
  135. async def __aenter__(self) -> Self:
  136. return self
  137. async def __aexit__(
  138. self,
  139. exc_type: type[BaseException] | None,
  140. exc: BaseException | None,
  141. exc_tb: TracebackType | None,
  142. ) -> None:
  143. await self.close()
  144. async def close(self) -> None:
  145. """
  146. Close the response and release the connection.
  147. Automatically called if the response body is read to completion.
  148. """
  149. await self._response.aclose()
  150. async def get_final_response(self) -> ParsedResponse[TextFormatT]:
  151. """Waits until the stream has been read to completion and returns
  152. the accumulated `ParsedResponse` object.
  153. """
  154. await self.until_done()
  155. response = self._state._completed_response
  156. if not response:
  157. raise RuntimeError("Didn't receive a `response.completed` event.")
  158. return response
  159. async def until_done(self) -> Self:
  160. """Blocks until the stream has been consumed."""
  161. await consume_async_iterator(self)
  162. return self
  163. class AsyncResponseStreamManager(Generic[TextFormatT]):
  164. def __init__(
  165. self,
  166. api_request: Awaitable[AsyncStream[RawResponseStreamEvent]],
  167. *,
  168. text_format: type[TextFormatT] | Omit,
  169. input_tools: Iterable[ToolParam] | Omit,
  170. starting_after: int | None,
  171. ) -> None:
  172. self.__stream: AsyncResponseStream[TextFormatT] | None = None
  173. self.__api_request = api_request
  174. self.__text_format = text_format
  175. self.__input_tools = input_tools
  176. self.__starting_after = starting_after
  177. async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
  178. raw_stream = await self.__api_request
  179. self.__stream = AsyncResponseStream(
  180. raw_stream=raw_stream,
  181. text_format=self.__text_format,
  182. input_tools=self.__input_tools,
  183. starting_after=self.__starting_after,
  184. )
  185. return self.__stream
  186. async def __aexit__(
  187. self,
  188. exc_type: type[BaseException] | None,
  189. exc: BaseException | None,
  190. exc_tb: TracebackType | None,
  191. ) -> None:
  192. if self.__stream is not None:
  193. await self.__stream.close()
  194. class ResponseStreamState(Generic[TextFormatT]):
  195. def __init__(
  196. self,
  197. *,
  198. input_tools: Iterable[ToolParam] | Omit,
  199. text_format: type[TextFormatT] | Omit,
  200. ) -> None:
  201. self.__current_snapshot: ParsedResponseSnapshot | None = None
  202. self._completed_response: ParsedResponse[TextFormatT] | None = None
  203. self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
  204. self._text_format = text_format
  205. self._rich_text_format: type | Omit = text_format if inspect.isclass(text_format) else omit
  206. def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
  207. self.__current_snapshot = snapshot = self.accumulate_event(event)
  208. events: List[ResponseStreamEvent[TextFormatT]] = []
  209. if event.type == "response.output_text.delta":
  210. output = snapshot.output[event.output_index]
  211. assert output.type == "message"
  212. content = output.content[event.content_index]
  213. assert content.type == "output_text"
  214. events.append(
  215. build(
  216. ResponseTextDeltaEvent,
  217. content_index=event.content_index,
  218. delta=event.delta,
  219. item_id=event.item_id,
  220. output_index=event.output_index,
  221. sequence_number=event.sequence_number,
  222. logprobs=event.logprobs,
  223. type="response.output_text.delta",
  224. snapshot=content.text,
  225. )
  226. )
  227. elif event.type == "response.output_text.done":
  228. output = snapshot.output[event.output_index]
  229. assert output.type == "message"
  230. content = output.content[event.content_index]
  231. assert content.type == "output_text"
  232. events.append(
  233. build(
  234. ResponseTextDoneEvent[TextFormatT],
  235. content_index=event.content_index,
  236. item_id=event.item_id,
  237. output_index=event.output_index,
  238. sequence_number=event.sequence_number,
  239. logprobs=event.logprobs,
  240. type="response.output_text.done",
  241. text=event.text,
  242. parsed=parse_text(event.text, text_format=self._text_format),
  243. )
  244. )
  245. elif event.type == "response.function_call_arguments.delta":
  246. output = snapshot.output[event.output_index]
  247. assert output.type == "function_call"
  248. events.append(
  249. build(
  250. ResponseFunctionCallArgumentsDeltaEvent,
  251. delta=event.delta,
  252. item_id=event.item_id,
  253. output_index=event.output_index,
  254. sequence_number=event.sequence_number,
  255. type="response.function_call_arguments.delta",
  256. snapshot=output.arguments,
  257. )
  258. )
  259. elif event.type == "response.completed":
  260. response = self._completed_response
  261. assert response is not None
  262. events.append(
  263. build(
  264. ResponseCompletedEvent,
  265. sequence_number=event.sequence_number,
  266. type="response.completed",
  267. response=response,
  268. )
  269. )
  270. else:
  271. events.append(event)
  272. return events
  273. def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
  274. snapshot = self.__current_snapshot
  275. if snapshot is None:
  276. return self._create_initial_response(event)
  277. if event.type == "response.output_item.added":
  278. if event.item.type == "function_call":
  279. snapshot.output.append(
  280. construct_type_unchecked(
  281. type_=cast(Any, ParsedResponseFunctionToolCall), value=event.item.to_dict()
  282. )
  283. )
  284. elif event.item.type == "message":
  285. snapshot.output.append(
  286. construct_type_unchecked(type_=cast(Any, ParsedResponseOutputMessage), value=event.item.to_dict())
  287. )
  288. else:
  289. snapshot.output.append(event.item)
  290. elif event.type == "response.content_part.added":
  291. output = snapshot.output[event.output_index]
  292. if output.type == "message":
  293. output.content.append(
  294. construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict())
  295. )
  296. elif event.type == "response.output_text.delta":
  297. output = snapshot.output[event.output_index]
  298. if output.type == "message":
  299. content = output.content[event.content_index]
  300. assert content.type == "output_text"
  301. content.text += event.delta
  302. elif event.type == "response.function_call_arguments.delta":
  303. output = snapshot.output[event.output_index]
  304. if output.type == "function_call":
  305. output.arguments += event.delta
  306. elif event.type == "response.completed":
  307. self._completed_response = parse_response(
  308. text_format=self._text_format,
  309. response=event.response,
  310. input_tools=self._input_tools,
  311. )
  312. return snapshot
  313. def _create_initial_response(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
  314. if event.type != "response.created":
  315. raise RuntimeError(f"Expected to have received `response.created` before `{event.type}`")
  316. return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())