| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372 |
- from __future__ import annotations
- import inspect
- from types import TracebackType
- from typing import Any, List, Generic, Iterable, Awaitable, cast
- from typing_extensions import Self, Callable, Iterator, AsyncIterator
- from ._types import ParsedResponseSnapshot
- from ._events import (
- ResponseStreamEvent,
- ResponseTextDoneEvent,
- ResponseCompletedEvent,
- ResponseTextDeltaEvent,
- ResponseFunctionCallArgumentsDeltaEvent,
- )
- from ...._types import Omit, omit
- from ...._utils import is_given, consume_sync_iterator, consume_async_iterator
- from ...._models import build, construct_type_unchecked
- from ...._streaming import Stream, AsyncStream
- from ....types.responses import ParsedResponse, ResponseStreamEvent as RawResponseStreamEvent
- from ..._parsing._responses import TextFormatT, parse_text, parse_response
- from ....types.responses.tool_param import ToolParam
- from ....types.responses.parsed_response import (
- ParsedContent,
- ParsedResponseOutputMessage,
- ParsedResponseFunctionToolCall,
- )
- class ResponseStream(Generic[TextFormatT]):
- def __init__(
- self,
- *,
- raw_stream: Stream[RawResponseStreamEvent],
- text_format: type[TextFormatT] | Omit,
- input_tools: Iterable[ToolParam] | Omit,
- starting_after: int | None,
- ) -> None:
- self._raw_stream = raw_stream
- self._response = raw_stream.response
- self._iterator = self.__stream__()
- self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
- self._starting_after = starting_after
- def __next__(self) -> ResponseStreamEvent[TextFormatT]:
- return self._iterator.__next__()
- def __iter__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
- for item in self._iterator:
- yield item
- def __enter__(self) -> Self:
- return self
- def __stream__(self) -> Iterator[ResponseStreamEvent[TextFormatT]]:
- for sse_event in self._raw_stream:
- events_to_fire = self._state.handle_event(sse_event)
- for event in events_to_fire:
- if self._starting_after is None or event.sequence_number > self._starting_after:
- yield event
- def __exit__(
- self,
- exc_type: type[BaseException] | None,
- exc: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> None:
- self.close()
- def close(self) -> None:
- """
- Close the response and release the connection.
- Automatically called if the response body is read to completion.
- """
- self._response.close()
- def get_final_response(self) -> ParsedResponse[TextFormatT]:
- """Waits until the stream has been read to completion and returns
- the accumulated `ParsedResponse` object.
- """
- self.until_done()
- response = self._state._completed_response
- if not response:
- raise RuntimeError("Didn't receive a `response.completed` event.")
- return response
- def until_done(self) -> Self:
- """Blocks until the stream has been consumed."""
- consume_sync_iterator(self)
- return self
- class ResponseStreamManager(Generic[TextFormatT]):
- def __init__(
- self,
- api_request: Callable[[], Stream[RawResponseStreamEvent]],
- *,
- text_format: type[TextFormatT] | Omit,
- input_tools: Iterable[ToolParam] | Omit,
- starting_after: int | None,
- ) -> None:
- self.__stream: ResponseStream[TextFormatT] | None = None
- self.__api_request = api_request
- self.__text_format = text_format
- self.__input_tools = input_tools
- self.__starting_after = starting_after
- def __enter__(self) -> ResponseStream[TextFormatT]:
- raw_stream = self.__api_request()
- self.__stream = ResponseStream(
- raw_stream=raw_stream,
- text_format=self.__text_format,
- input_tools=self.__input_tools,
- starting_after=self.__starting_after,
- )
- return self.__stream
- def __exit__(
- self,
- exc_type: type[BaseException] | None,
- exc: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> None:
- if self.__stream is not None:
- self.__stream.close()
- class AsyncResponseStream(Generic[TextFormatT]):
- def __init__(
- self,
- *,
- raw_stream: AsyncStream[RawResponseStreamEvent],
- text_format: type[TextFormatT] | Omit,
- input_tools: Iterable[ToolParam] | Omit,
- starting_after: int | None,
- ) -> None:
- self._raw_stream = raw_stream
- self._response = raw_stream.response
- self._iterator = self.__stream__()
- self._state = ResponseStreamState(text_format=text_format, input_tools=input_tools)
- self._starting_after = starting_after
- async def __anext__(self) -> ResponseStreamEvent[TextFormatT]:
- return await self._iterator.__anext__()
- async def __aiter__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
- async for item in self._iterator:
- yield item
- async def __stream__(self) -> AsyncIterator[ResponseStreamEvent[TextFormatT]]:
- async for sse_event in self._raw_stream:
- events_to_fire = self._state.handle_event(sse_event)
- for event in events_to_fire:
- if self._starting_after is None or event.sequence_number > self._starting_after:
- yield event
- async def __aenter__(self) -> Self:
- return self
- async def __aexit__(
- self,
- exc_type: type[BaseException] | None,
- exc: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> None:
- await self.close()
- async def close(self) -> None:
- """
- Close the response and release the connection.
- Automatically called if the response body is read to completion.
- """
- await self._response.aclose()
- async def get_final_response(self) -> ParsedResponse[TextFormatT]:
- """Waits until the stream has been read to completion and returns
- the accumulated `ParsedResponse` object.
- """
- await self.until_done()
- response = self._state._completed_response
- if not response:
- raise RuntimeError("Didn't receive a `response.completed` event.")
- return response
- async def until_done(self) -> Self:
- """Blocks until the stream has been consumed."""
- await consume_async_iterator(self)
- return self
- class AsyncResponseStreamManager(Generic[TextFormatT]):
- def __init__(
- self,
- api_request: Awaitable[AsyncStream[RawResponseStreamEvent]],
- *,
- text_format: type[TextFormatT] | Omit,
- input_tools: Iterable[ToolParam] | Omit,
- starting_after: int | None,
- ) -> None:
- self.__stream: AsyncResponseStream[TextFormatT] | None = None
- self.__api_request = api_request
- self.__text_format = text_format
- self.__input_tools = input_tools
- self.__starting_after = starting_after
- async def __aenter__(self) -> AsyncResponseStream[TextFormatT]:
- raw_stream = await self.__api_request
- self.__stream = AsyncResponseStream(
- raw_stream=raw_stream,
- text_format=self.__text_format,
- input_tools=self.__input_tools,
- starting_after=self.__starting_after,
- )
- return self.__stream
- async def __aexit__(
- self,
- exc_type: type[BaseException] | None,
- exc: BaseException | None,
- exc_tb: TracebackType | None,
- ) -> None:
- if self.__stream is not None:
- await self.__stream.close()
- class ResponseStreamState(Generic[TextFormatT]):
- def __init__(
- self,
- *,
- input_tools: Iterable[ToolParam] | Omit,
- text_format: type[TextFormatT] | Omit,
- ) -> None:
- self.__current_snapshot: ParsedResponseSnapshot | None = None
- self._completed_response: ParsedResponse[TextFormatT] | None = None
- self._input_tools = [tool for tool in input_tools] if is_given(input_tools) else []
- self._text_format = text_format
- self._rich_text_format: type | Omit = text_format if inspect.isclass(text_format) else omit
- def handle_event(self, event: RawResponseStreamEvent) -> List[ResponseStreamEvent[TextFormatT]]:
- self.__current_snapshot = snapshot = self.accumulate_event(event)
- events: List[ResponseStreamEvent[TextFormatT]] = []
- if event.type == "response.output_text.delta":
- output = snapshot.output[event.output_index]
- assert output.type == "message"
- content = output.content[event.content_index]
- assert content.type == "output_text"
- events.append(
- build(
- ResponseTextDeltaEvent,
- content_index=event.content_index,
- delta=event.delta,
- item_id=event.item_id,
- output_index=event.output_index,
- sequence_number=event.sequence_number,
- logprobs=event.logprobs,
- type="response.output_text.delta",
- snapshot=content.text,
- )
- )
- elif event.type == "response.output_text.done":
- output = snapshot.output[event.output_index]
- assert output.type == "message"
- content = output.content[event.content_index]
- assert content.type == "output_text"
- events.append(
- build(
- ResponseTextDoneEvent[TextFormatT],
- content_index=event.content_index,
- item_id=event.item_id,
- output_index=event.output_index,
- sequence_number=event.sequence_number,
- logprobs=event.logprobs,
- type="response.output_text.done",
- text=event.text,
- parsed=parse_text(event.text, text_format=self._text_format),
- )
- )
- elif event.type == "response.function_call_arguments.delta":
- output = snapshot.output[event.output_index]
- assert output.type == "function_call"
- events.append(
- build(
- ResponseFunctionCallArgumentsDeltaEvent,
- delta=event.delta,
- item_id=event.item_id,
- output_index=event.output_index,
- sequence_number=event.sequence_number,
- type="response.function_call_arguments.delta",
- snapshot=output.arguments,
- )
- )
- elif event.type == "response.completed":
- response = self._completed_response
- assert response is not None
- events.append(
- build(
- ResponseCompletedEvent,
- sequence_number=event.sequence_number,
- type="response.completed",
- response=response,
- )
- )
- else:
- events.append(event)
- return events
- def accumulate_event(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
- snapshot = self.__current_snapshot
- if snapshot is None:
- return self._create_initial_response(event)
- if event.type == "response.output_item.added":
- if event.item.type == "function_call":
- snapshot.output.append(
- construct_type_unchecked(
- type_=cast(Any, ParsedResponseFunctionToolCall), value=event.item.to_dict()
- )
- )
- elif event.item.type == "message":
- snapshot.output.append(
- construct_type_unchecked(type_=cast(Any, ParsedResponseOutputMessage), value=event.item.to_dict())
- )
- else:
- snapshot.output.append(event.item)
- elif event.type == "response.content_part.added":
- output = snapshot.output[event.output_index]
- if output.type == "message":
- output.content.append(
- construct_type_unchecked(type_=cast(Any, ParsedContent), value=event.part.to_dict())
- )
- elif event.type == "response.output_text.delta":
- output = snapshot.output[event.output_index]
- if output.type == "message":
- content = output.content[event.content_index]
- assert content.type == "output_text"
- content.text += event.delta
- elif event.type == "response.function_call_arguments.delta":
- output = snapshot.output[event.output_index]
- if output.type == "function_call":
- output.arguments += event.delta
- elif event.type == "response.completed":
- self._completed_response = parse_response(
- text_format=self._text_format,
- response=event.response,
- input_tools=self._input_tools,
- )
- return snapshot
- def _create_initial_response(self, event: RawResponseStreamEvent) -> ParsedResponseSnapshot:
- if event.type != "response.created":
- raise RuntimeError(f"Expected to have received `response.created` before `{event.type}`")
- return construct_type_unchecked(type_=ParsedResponseSnapshot, value=event.response.to_dict())
|