| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038 |
- from __future__ import annotations
- import asyncio
- from types import TracebackType
- from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Iterable, Iterator, cast
- from typing_extensions import Awaitable, AsyncIterable, AsyncIterator, assert_never
- import httpx
- from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator
- from ..._compat import model_dump
- from ..._models import construct_type
- from ..._streaming import Stream, AsyncStream
- from ...types.beta import AssistantStreamEvent
- from ...types.beta.threads import (
- Run,
- Text,
- Message,
- ImageFile,
- TextDelta,
- MessageDelta,
- MessageContent,
- MessageContentDelta,
- )
- from ...types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta
class AssistantEventHandler:
    """Synchronous event handler for an Assistants API event stream.

    Iterating the handler yields every `AssistantStreamEvent` read from the
    stream bound via `_init()`. Before an event is yielded it is folded into
    the message / run step snapshots and the matching `on_*` callbacks are
    invoked. Subclass and override the `on_*` methods to react to events; the
    default implementations do nothing.
    """

    text_deltas: Iterable[str]
    """Iterator over just the text deltas in the stream.

    This corresponds to the `thread.message.delta` event
    in the API.

    ```py
    for text in stream.text_deltas:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self) -> None:
        self._current_event: AssistantStreamEvent | None = None

        # Index / accumulated value of the message content block being streamed.
        self._current_message_content_index: int | None = None
        self._current_message_content: MessageContent | None = None

        # Index / accumulated value of the tool call being streamed.
        self._current_tool_call_index: int | None = None
        self._current_tool_call: ToolCall | None = None

        self.__current_run_step_id: str | None = None
        self.__current_run: Run | None = None
        self.__run_step_snapshots: dict[str, RunStep] = {}
        self.__message_snapshots: dict[str, Message] = {}
        self.__current_message_snapshot: Message | None = None

        # Both generators are lazy: nothing is read until the first `next()`,
        # so it is fine that the stream is only attached later by `_init()`.
        self.text_deltas = self.__text_deltas__()
        self._iterator = self.__stream__()
        self.__stream: Stream[AssistantStreamEvent] | None = None

    def _init(self, stream: Stream[AssistantStreamEvent]) -> None:
        """Bind this handler to `stream`.

        Raises:
            RuntimeError: if the handler is already bound to a stream.
        """
        if self.__stream:
            raise RuntimeError(
                "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
            )

        self.__stream = stream

    def __next__(self) -> AssistantStreamEvent:
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[AssistantStreamEvent]:
        # A plain loop rather than `yield from`: closing the generator returned
        # here must not close the shared `self._iterator`.
        for item in self._iterator:
            yield item

    @property
    def current_event(self) -> AssistantStreamEvent | None:
        """The event currently being dispatched, or `None` outside of dispatch."""
        return self._current_event

    @property
    def current_run(self) -> Run | None:
        """The most recent `Run` object received from a `thread.run.*` event."""
        return self.__current_run

    @property
    def current_run_step_snapshot(self) -> RunStep | None:
        """The accumulated snapshot of the run step in progress, if any."""
        if not self.__current_run_step_id:
            return None

        return self.__run_step_snapshots[self.__current_run_step_id]

    @property
    def current_message_snapshot(self) -> Message | None:
        """The accumulated snapshot of the most recent message, if any."""
        return self.__current_message_snapshot

    def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called when the context manager exits.
        """
        if self.__stream:
            self.__stream.close()

    def until_done(self) -> None:
        """Waits until the stream has been consumed"""
        consume_sync_iterator(self)

    def get_final_run(self) -> Run:
        """Wait for the stream to finish and returns the completed Run object

        Raises:
            RuntimeError: if no run object was received.
        """
        self.until_done()

        if not self.__current_run:
            raise RuntimeError("No final run object found")

        return self.__current_run

    def get_final_run_steps(self) -> list[RunStep]:
        """Wait for the stream to finish and returns the steps taken in this run

        Raises:
            RuntimeError: if no run steps were received.
        """
        self.until_done()

        if not self.__run_step_snapshots:
            raise RuntimeError("No run steps found")

        return list(self.__run_step_snapshots.values())

    def get_final_messages(self) -> list[Message]:
        """Wait for the stream to finish and returns the messages emitted in this run

        Raises:
            RuntimeError: if no messages were received.
        """
        self.until_done()

        if not self.__message_snapshots:
            raise RuntimeError("No messages found")

        return list(self.__message_snapshots.values())

    def __text_deltas__(self) -> Iterator[str]:
        """Yield the non-empty text fragments of every `thread.message.delta` event."""
        for event in self:
            if event.event != "thread.message.delta":
                continue

            for content_delta in event.data.delta.content or []:
                if content_delta.type == "text" and content_delta.text and content_delta.text.value:
                    yield content_delta.text.value

    # event handlers

    def on_end(self) -> None:
        """Fires when the stream has finished.

        This happens if the stream is read to completion
        or if an exception occurs during iteration.
        """

    def on_event(self, event: AssistantStreamEvent) -> None:
        """Callback that is fired for every Server-Sent-Event"""

    def on_run_step_created(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is created"""

    def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        """Callback that is fired whenever a run step delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the run step. For example, a tool calls event may
        look like this:

        # delta
        tool_calls=[
            RunStepDeltaToolCallsCodeInterpreter(
                index=0,
                type='code_interpreter',
                id=None,
                code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
            )
        ]
        # snapshot
        tool_calls=[
            CodeToolCall(
                id='call_wKayJlcYV12NiadiZuJXxcfx',
                code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
                type='code_interpreter',
                index=0
            )
        ],
        """

    def on_run_step_done(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is completed"""

    def on_tool_call_created(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is created"""

    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Callback that is fired when a tool call delta is encountered"""

    def on_tool_call_done(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is finished"""

    def on_exception(self, exception: Exception) -> None:
        """Fired whenever an exception happens during streaming"""

    def on_timeout(self) -> None:
        """Fires if the request times out"""

    def on_message_created(self, message: Message) -> None:
        """Callback that is fired when a message is created"""

    def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
        """Callback that is fired whenever a message delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the message. For example, a text content event may
        look like this:

        # delta
        MessageDeltaText(
            index=0,
            type='text',
            text=Text(
                value=' Jane'
            ),
        )
        # snapshot
        MessageContentText(
            index=0,
            type='text',
            text=Text(
                value='Certainly, Jane'
            ),
        )
        """

    def on_message_done(self, message: Message) -> None:
        """Callback that is fired when a message is completed"""

    def on_text_created(self, text: Text) -> None:
        """Callback that is fired when a text content block is created"""

    def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Callback that is fired whenever a text content delta is returned
        by the API.

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the text. For example:

        on_text_delta(TextDelta(value="The"), Text(value="The")),
        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
        """

    def on_text_done(self, text: Text) -> None:
        """Callback that is fired when a text content block is finished"""

    def on_image_file_done(self, image_file: ImageFile) -> None:
        """Callback that is fired when an image file block is finished"""

    def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
        """Fold `event` into the snapshots and fire the matching callbacks.

        `on_event` fires first for every event; the snapshots are then updated
        and the event-specific callbacks are dispatched. `current_event` is set
        for the duration of the call and reset to `None` afterwards.
        """
        self._current_event = event
        self.on_event(event)

        self.__current_message_snapshot, new_content = accumulate_event(
            event=event,
            current_message_snapshot=self.__current_message_snapshot,
        )
        if self.__current_message_snapshot is not None:
            self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot

        accumulate_run_step(
            event=event,
            run_step_snapshots=self.__run_step_snapshots,
        )

        # Content blocks that did not exist before this event.
        for content_delta in new_content:
            assert self.__current_message_snapshot is not None

            block = self.__current_message_snapshot.content[content_delta.index]
            if block.type == "text":
                self.on_text_created(block.text)

        # NOTE(review): the branches compare `event.event` with `==` one literal
        # at a time, presumably so type checkers can narrow the event union;
        # keep that shape when editing.
        if (
            event.event == "thread.run.completed"
            or event.event == "thread.run.cancelled"
            or event.event == "thread.run.expired"
            or event.event == "thread.run.failed"
            or event.event == "thread.run.requires_action"
            or event.event == "thread.run.incomplete"
        ):
            self.__current_run = event.data
            if self._current_tool_call:
                self.on_tool_call_done(self._current_tool_call)
        elif (
            event.event == "thread.run.created"
            or event.event == "thread.run.in_progress"
            or event.event == "thread.run.cancelling"
            or event.event == "thread.run.queued"
        ):
            self.__current_run = event.data
        elif event.event == "thread.message.created":
            self.on_message_created(event.data)
        elif event.event == "thread.message.delta":
            snapshot = self.__current_message_snapshot
            assert snapshot is not None

            message_delta = event.data.delta
            if message_delta.content is not None:
                for content_delta in message_delta.content:
                    if content_delta.type == "text" and content_delta.text:
                        snapshot_content = snapshot.content[content_delta.index]
                        assert snapshot_content.type == "text"
                        self.on_text_delta(content_delta.text, snapshot_content.text)

                    # If the delta is for a different message content block,
                    # emit on_text_done/on_image_file_done for the previous one.
                    # (on_text_created for the new block was already emitted
                    # from the `new_content` loop above.)
                    if content_delta.index != self._current_message_content_index:
                        if self._current_message_content is not None:
                            if self._current_message_content.type == "text":
                                self.on_text_done(self._current_message_content.text)
                            elif self._current_message_content.type == "image_file":
                                self.on_image_file_done(self._current_message_content.image_file)

                        self._current_message_content_index = content_delta.index

                    # Update the current_message_content (delta event is correctly emitted already)
                    self._current_message_content = snapshot.content[content_delta.index]

            self.on_message_delta(event.data.delta, snapshot)
        elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
            self.__current_message_snapshot = event.data
            self.__message_snapshots[event.data.id] = event.data

            if self._current_message_content_index is not None:
                content = event.data.content[self._current_message_content_index]
                if content.type == "text":
                    self.on_text_done(content.text)
                elif content.type == "image_file":
                    self.on_image_file_done(content.image_file)

            self.on_message_done(event.data)
        elif event.event == "thread.run.step.created":
            self.__current_run_step_id = event.data.id
            self.on_run_step_created(event.data)
        elif event.event == "thread.run.step.in_progress":
            self.__current_run_step_id = event.data.id
        elif event.event == "thread.run.step.delta":
            step_snapshot = self.__run_step_snapshots[event.data.id]

            run_step_delta = event.data.delta
            if (
                run_step_delta.step_details
                and run_step_delta.step_details.type == "tool_calls"
                and run_step_delta.step_details.tool_calls is not None
            ):
                assert step_snapshot.step_details.type == "tool_calls"
                for tool_call_delta in run_step_delta.step_details.tool_calls:
                    if tool_call_delta.index == self._current_tool_call_index:
                        self.on_tool_call_delta(
                            tool_call_delta,
                            step_snapshot.step_details.tool_calls[tool_call_delta.index],
                        )

                    # If the delta is for a new tool call:
                    # - emit on_tool_call_done for the previous tool_call
                    # - emit on_tool_call_created for the new tool_call
                    if tool_call_delta.index != self._current_tool_call_index:
                        if self._current_tool_call is not None:
                            self.on_tool_call_done(self._current_tool_call)

                        self._current_tool_call_index = tool_call_delta.index
                        self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
                        self.on_tool_call_created(self._current_tool_call)

                    # Update the current_tool_call (delta event is correctly emitted already)
                    self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]

            self.on_run_step_delta(
                event.data.delta,
                step_snapshot,
            )
        elif (
            event.event == "thread.run.step.completed"
            or event.event == "thread.run.step.cancelled"
            or event.event == "thread.run.step.expired"
            or event.event == "thread.run.step.failed"
        ):
            if self._current_tool_call:
                self.on_tool_call_done(self._current_tool_call)

            self.on_run_step_done(event.data)
            self.__current_run_step_id = None
        elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
            # currently no special handling
            ...
        else:
            # we only want to error at build-time
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(event)

        self._current_event = None

    def __stream__(self) -> Iterator[AssistantStreamEvent]:
        """Yield events from the bound stream, dispatching callbacks for each.

        Raises:
            RuntimeError: if no stream has been bound via `_init()`.
        """
        stream = self.__stream
        if not stream:
            raise RuntimeError("Stream has not been started yet")

        try:
            for event in stream:
                self._emit_sse_event(event)

                yield event
        except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
            # Timeouts notify both the dedicated hook and the generic one.
            self.on_timeout()
            self.on_exception(exc)
            raise
        except Exception as exc:
            self.on_exception(exc)
            raise
        finally:
            self.on_end()
AssistantEventHandlerT = TypeVar("AssistantEventHandlerT", bound=AssistantEventHandler)


class AssistantStreamManager(Generic[AssistantEventHandlerT]):
    """Wrapper over AssistantEventHandler that is returned by `.stream()`
    so that a context manager can be used.

    ```py
    with client.threads.create_and_run_stream(...) as stream:
        for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Callable[[], Stream[AssistantStreamEvent]],
        *,
        event_handler: AssistantEventHandlerT,
    ) -> None:
        """Store the deferred request and the handler that will consume it.

        Args:
            api_request: zero-argument callable that starts the request and
                returns the event stream; it is only invoked in `__enter__`.
            event_handler: handler instance to bind to the stream.
        """
        self.__stream: Stream[AssistantStreamEvent] | None = None
        self.__event_handler = event_handler
        self.__api_request = api_request

    def __enter__(self) -> AssistantEventHandlerT:
        """Start the request, bind the stream to the handler and return the handler."""
        stream = self.__api_request()
        try:
            self.__event_handler._init(stream)
        except BaseException:
            # `__exit__` is not invoked when `__enter__` raises, so the stream
            # that was just opened has to be released here.
            stream.close()
            raise

        self.__stream = stream
        return self.__event_handler

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the stream, if one was opened; exceptions are not suppressed."""
        if self.__stream is not None:
            self.__stream.close()
class AsyncAssistantEventHandler:
    """Asynchronous event handler for an Assistants API event stream.

    Iterating the handler with `async for` yields every `AssistantStreamEvent`
    read from the stream bound via `_init()`. Before an event is yielded it is
    folded into the message / run step snapshots and the matching `on_*`
    callbacks are awaited. Subclass and override the `on_*` coroutines to react
    to events; the default implementations do nothing.
    """

    text_deltas: AsyncIterable[str]
    """Iterator over just the text deltas in the stream.

    This corresponds to the `thread.message.delta` event
    in the API.

    ```py
    async for text in stream.text_deltas:
        print(text, end="", flush=True)
    print()
    ```
    """

    def __init__(self) -> None:
        self._current_event: AssistantStreamEvent | None = None

        # Index / accumulated value of the message content block being streamed.
        self._current_message_content_index: int | None = None
        self._current_message_content: MessageContent | None = None

        # Index / accumulated value of the tool call being streamed.
        self._current_tool_call_index: int | None = None
        self._current_tool_call: ToolCall | None = None

        self.__current_run_step_id: str | None = None
        self.__current_run: Run | None = None
        self.__run_step_snapshots: dict[str, RunStep] = {}
        self.__message_snapshots: dict[str, Message] = {}
        self.__current_message_snapshot: Message | None = None

        # Both async generators are lazy: nothing is read until first awaited,
        # so it is fine that the stream is only attached later by `_init()`.
        self.text_deltas = self.__text_deltas__()
        self._iterator = self.__stream__()
        self.__stream: AsyncStream[AssistantStreamEvent] | None = None

    def _init(self, stream: AsyncStream[AssistantStreamEvent]) -> None:
        """Bind this handler to `stream`.

        Raises:
            RuntimeError: if the handler is already bound to a stream.
        """
        if self.__stream:
            raise RuntimeError(
                "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
            )

        self.__stream = stream

    async def __anext__(self) -> AssistantStreamEvent:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[AssistantStreamEvent]:
        async for item in self._iterator:
            yield item

    async def close(self) -> None:
        """
        Close the response and release the connection.

        Automatically called when the context manager exits.
        """
        if self.__stream:
            await self.__stream.close()

    @property
    def current_event(self) -> AssistantStreamEvent | None:
        """The event currently being dispatched, or `None` outside of dispatch."""
        return self._current_event

    @property
    def current_run(self) -> Run | None:
        """The most recent `Run` object received from a `thread.run.*` event."""
        return self.__current_run

    @property
    def current_run_step_snapshot(self) -> RunStep | None:
        """The accumulated snapshot of the run step in progress, if any."""
        if not self.__current_run_step_id:
            return None

        return self.__run_step_snapshots[self.__current_run_step_id]

    @property
    def current_message_snapshot(self) -> Message | None:
        """The accumulated snapshot of the most recent message, if any."""
        return self.__current_message_snapshot

    async def until_done(self) -> None:
        """Waits until the stream has been consumed"""
        await consume_async_iterator(self)

    async def get_final_run(self) -> Run:
        """Wait for the stream to finish and returns the completed Run object

        Raises:
            RuntimeError: if no run object was received.
        """
        await self.until_done()

        if not self.__current_run:
            raise RuntimeError("No final run object found")

        return self.__current_run

    async def get_final_run_steps(self) -> list[RunStep]:
        """Wait for the stream to finish and returns the steps taken in this run

        Raises:
            RuntimeError: if no run steps were received.
        """
        await self.until_done()

        if not self.__run_step_snapshots:
            raise RuntimeError("No run steps found")

        return list(self.__run_step_snapshots.values())

    async def get_final_messages(self) -> list[Message]:
        """Wait for the stream to finish and returns the messages emitted in this run

        Raises:
            RuntimeError: if no messages were received.
        """
        await self.until_done()

        if not self.__message_snapshots:
            raise RuntimeError("No messages found")

        return list(self.__message_snapshots.values())

    async def __text_deltas__(self) -> AsyncIterator[str]:
        """Yield the non-empty text fragments of every `thread.message.delta` event."""
        async for event in self:
            if event.event != "thread.message.delta":
                continue

            for content_delta in event.data.delta.content or []:
                if content_delta.type == "text" and content_delta.text and content_delta.text.value:
                    yield content_delta.text.value

    # event handlers

    async def on_end(self) -> None:
        """Fires when the stream has finished.

        This happens if the stream is read to completion
        or if an exception occurs during iteration.
        """

    async def on_event(self, event: AssistantStreamEvent) -> None:
        """Callback that is fired for every Server-Sent-Event"""

    async def on_run_step_created(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is created"""

    async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
        """Callback that is fired whenever a run step delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the run step. For example, a tool calls event may
        look like this:

        # delta
        tool_calls=[
            RunStepDeltaToolCallsCodeInterpreter(
                index=0,
                type='code_interpreter',
                id=None,
                code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
            )
        ]
        # snapshot
        tool_calls=[
            CodeToolCall(
                id='call_wKayJlcYV12NiadiZuJXxcfx',
                code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
                type='code_interpreter',
                index=0
            )
        ],
        """

    async def on_run_step_done(self, run_step: RunStep) -> None:
        """Callback that is fired when a run step is completed"""

    async def on_tool_call_created(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is created"""

    async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Callback that is fired when a tool call delta is encountered"""

    async def on_tool_call_done(self, tool_call: ToolCall) -> None:
        """Callback that is fired when a tool call is finished"""

    async def on_exception(self, exception: Exception) -> None:
        """Fired whenever an exception happens during streaming"""

    async def on_timeout(self) -> None:
        """Fires if the request times out"""

    async def on_message_created(self, message: Message) -> None:
        """Callback that is fired when a message is created"""

    async def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
        """Callback that is fired whenever a message delta is returned from the API

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the message. For example, a text content event may
        look like this:

        # delta
        MessageDeltaText(
            index=0,
            type='text',
            text=Text(
                value=' Jane'
            ),
        )
        # snapshot
        MessageContentText(
            index=0,
            type='text',
            text=Text(
                value='Certainly, Jane'
            ),
        )
        """

    async def on_message_done(self, message: Message) -> None:
        """Callback that is fired when a message is completed"""

    async def on_text_created(self, text: Text) -> None:
        """Callback that is fired when a text content block is created"""

    async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Callback that is fired whenever a text content delta is returned
        by the API.

        The first argument is just the delta as sent by the API and the second argument
        is the accumulated snapshot of the text. For example:

        on_text_delta(TextDelta(value="The"), Text(value="The")),
        on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
        on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
        on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
        on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
        """

    async def on_text_done(self, text: Text) -> None:
        """Callback that is fired when a text content block is finished"""

    async def on_image_file_done(self, image_file: ImageFile) -> None:
        """Callback that is fired when an image file block is finished"""

    async def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
        """Fold `event` into the snapshots and await the matching callbacks.

        `on_event` fires first for every event; the snapshots are then updated
        and the event-specific callbacks are dispatched. `current_event` is set
        for the duration of the call and reset to `None` afterwards.
        """
        self._current_event = event
        await self.on_event(event)

        self.__current_message_snapshot, new_content = accumulate_event(
            event=event,
            current_message_snapshot=self.__current_message_snapshot,
        )
        if self.__current_message_snapshot is not None:
            self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot

        accumulate_run_step(
            event=event,
            run_step_snapshots=self.__run_step_snapshots,
        )

        # Content blocks that did not exist before this event.
        for content_delta in new_content:
            assert self.__current_message_snapshot is not None

            block = self.__current_message_snapshot.content[content_delta.index]
            if block.type == "text":
                await self.on_text_created(block.text)

        # NOTE(review): the branches compare `event.event` with `==` one literal
        # at a time, presumably so type checkers can narrow the event union;
        # keep that shape when editing.
        if (
            event.event == "thread.run.completed"
            or event.event == "thread.run.cancelled"
            or event.event == "thread.run.expired"
            or event.event == "thread.run.failed"
            or event.event == "thread.run.requires_action"
            or event.event == "thread.run.incomplete"
        ):
            self.__current_run = event.data
            if self._current_tool_call:
                await self.on_tool_call_done(self._current_tool_call)
        elif (
            event.event == "thread.run.created"
            or event.event == "thread.run.in_progress"
            or event.event == "thread.run.cancelling"
            or event.event == "thread.run.queued"
        ):
            self.__current_run = event.data
        elif event.event == "thread.message.created":
            await self.on_message_created(event.data)
        elif event.event == "thread.message.delta":
            snapshot = self.__current_message_snapshot
            assert snapshot is not None

            message_delta = event.data.delta
            if message_delta.content is not None:
                for content_delta in message_delta.content:
                    if content_delta.type == "text" and content_delta.text:
                        snapshot_content = snapshot.content[content_delta.index]
                        assert snapshot_content.type == "text"
                        await self.on_text_delta(content_delta.text, snapshot_content.text)

                    # If the delta is for a different message content block,
                    # emit on_text_done/on_image_file_done for the previous one.
                    # (on_text_created for the new block was already emitted
                    # from the `new_content` loop above.)
                    if content_delta.index != self._current_message_content_index:
                        if self._current_message_content is not None:
                            if self._current_message_content.type == "text":
                                await self.on_text_done(self._current_message_content.text)
                            elif self._current_message_content.type == "image_file":
                                await self.on_image_file_done(self._current_message_content.image_file)

                        self._current_message_content_index = content_delta.index

                    # Update the current_message_content (delta event is correctly emitted already)
                    self._current_message_content = snapshot.content[content_delta.index]

            await self.on_message_delta(event.data.delta, snapshot)
        elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
            self.__current_message_snapshot = event.data
            self.__message_snapshots[event.data.id] = event.data

            if self._current_message_content_index is not None:
                content = event.data.content[self._current_message_content_index]
                if content.type == "text":
                    await self.on_text_done(content.text)
                elif content.type == "image_file":
                    await self.on_image_file_done(content.image_file)

            await self.on_message_done(event.data)
        elif event.event == "thread.run.step.created":
            self.__current_run_step_id = event.data.id
            await self.on_run_step_created(event.data)
        elif event.event == "thread.run.step.in_progress":
            self.__current_run_step_id = event.data.id
        elif event.event == "thread.run.step.delta":
            step_snapshot = self.__run_step_snapshots[event.data.id]

            run_step_delta = event.data.delta
            if (
                run_step_delta.step_details
                and run_step_delta.step_details.type == "tool_calls"
                and run_step_delta.step_details.tool_calls is not None
            ):
                assert step_snapshot.step_details.type == "tool_calls"
                for tool_call_delta in run_step_delta.step_details.tool_calls:
                    if tool_call_delta.index == self._current_tool_call_index:
                        await self.on_tool_call_delta(
                            tool_call_delta,
                            step_snapshot.step_details.tool_calls[tool_call_delta.index],
                        )

                    # If the delta is for a new tool call:
                    # - emit on_tool_call_done for the previous tool_call
                    # - emit on_tool_call_created for the new tool_call
                    if tool_call_delta.index != self._current_tool_call_index:
                        if self._current_tool_call is not None:
                            await self.on_tool_call_done(self._current_tool_call)

                        self._current_tool_call_index = tool_call_delta.index
                        self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
                        await self.on_tool_call_created(self._current_tool_call)

                    # Update the current_tool_call (delta event is correctly emitted already)
                    self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]

            await self.on_run_step_delta(
                event.data.delta,
                step_snapshot,
            )
        elif (
            event.event == "thread.run.step.completed"
            or event.event == "thread.run.step.cancelled"
            or event.event == "thread.run.step.expired"
            or event.event == "thread.run.step.failed"
        ):
            if self._current_tool_call:
                await self.on_tool_call_done(self._current_tool_call)

            await self.on_run_step_done(event.data)
            self.__current_run_step_id = None
        elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
            # currently no special handling
            ...
        else:
            # we only want to error at build-time
            if TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(event)

        self._current_event = None

    async def __stream__(self) -> AsyncIterator[AssistantStreamEvent]:
        """Yield events from the bound stream, dispatching callbacks for each.

        Raises:
            RuntimeError: if no stream has been bound via `_init()`.
        """
        stream = self.__stream
        if not stream:
            raise RuntimeError("Stream has not been started yet")

        try:
            async for event in stream:
                await self._emit_sse_event(event)

                yield event
        except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
            # Timeouts notify both the dedicated hook and the generic one.
            await self.on_timeout()
            await self.on_exception(exc)
            raise
        except Exception as exc:
            await self.on_exception(exc)
            raise
        finally:
            await self.on_end()
AsyncAssistantEventHandlerT = TypeVar("AsyncAssistantEventHandlerT", bound=AsyncAssistantEventHandler)


class AsyncAssistantStreamManager(Generic[AsyncAssistantEventHandlerT]):
    """Wrapper over AsyncAssistantEventHandler that is returned by `.stream()`
    so that an async context manager can be used without `await`ing the
    original client call.

    ```py
    async with client.threads.create_and_run_stream(...) as stream:
        async for event in stream:
            ...
    ```
    """

    def __init__(
        self,
        api_request: Awaitable[AsyncStream[AssistantStreamEvent]],
        *,
        event_handler: AsyncAssistantEventHandlerT,
    ) -> None:
        """Store the pending request and the handler that will consume it.

        Args:
            api_request: awaitable that resolves to the event stream; it is
                only awaited in `__aenter__`.
            event_handler: handler instance to bind to the stream.
        """
        self.__stream: AsyncStream[AssistantStreamEvent] | None = None
        self.__event_handler = event_handler
        self.__api_request = api_request

    async def __aenter__(self) -> AsyncAssistantEventHandlerT:
        """Await the request, bind the stream to the handler and return the handler."""
        stream = await self.__api_request
        try:
            self.__event_handler._init(stream)
        except BaseException:
            # `__aexit__` is not invoked when `__aenter__` raises, so the
            # stream that was just opened has to be released here.
            await stream.close()
            raise

        self.__stream = stream
        return self.__event_handler

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the stream, if one was opened; exceptions are not suppressed."""
        if self.__stream is not None:
            await self.__stream.close()
def accumulate_run_step(
    *,
    event: AssistantStreamEvent,
    run_step_snapshots: dict[str, RunStep],
) -> None:
    """Update `run_step_snapshots` in place from a single stream event.

    - `thread.run.step.created`: stores the event payload under its `id`.
    - `thread.run.step.delta`: looks up the stored snapshot for the payload's
      `id` and, when the payload carries a non-empty `delta`, replaces the
      snapshot with the result of merging the delta into it.
    - any other event: no-op.

    Raises:
        KeyError: if a delta arrives for an `id` with no stored snapshot.
    """
    if event.event == "thread.run.step.created":
        created_step = event.data
        run_step_snapshots[created_step.id] = created_step
    elif event.event == "thread.run.step.delta":
        step_delta = event.data
        # looked up before the emptiness check so an unknown id always raises
        previous = run_step_snapshots[step_delta.id]
        if not step_delta.delta:
            return

        previous_fields = cast(
            "dict[object, object]",
            model_dump(previous, exclude_unset=True, warnings=False),
        )
        delta_fields = cast(
            "dict[object, object]",
            model_dump(step_delta.delta, exclude_unset=True, warnings=False),
        )
        combined = accumulate_delta(previous_fields, delta_fields)
        run_step_snapshots[previous.id] = cast(RunStep, construct_type(type_=RunStep, value=combined))
def accumulate_event(
    *,
    event: AssistantStreamEvent,
    current_message_snapshot: Message | None,
) -> tuple[Message | None, list[MessageContentDelta]]:
    """Fold one stream event into the running message snapshot.

    Returns a `(snapshot, new_content)` tuple where `new_content` lists the
    content deltas that created a brand-new content block (as opposed to
    extending an existing one).

    - `thread.message.created`: the event payload becomes the snapshot.
    - `thread.message.delta`: every content delta is merged into the block at
      its `index`, or inserted as a new block if that index does not exist
      yet. The snapshot is mutated in place.
    - any other event: the snapshot is returned untouched.

    Raises:
        RuntimeError: if a message delta arrives while there is no snapshot.
    """
    if event.event == "thread.message.created":
        return event.data, []

    if event.event != "thread.message.delta":
        return current_message_snapshot, []

    if not current_message_snapshot:
        raise RuntimeError("Encountered a message delta with no previous snapshot")

    created: list[MessageContentDelta] = []

    parts = event.data.delta.content
    if not parts:
        return current_message_snapshot, created

    blocks = current_message_snapshot.content
    for part in parts:
        position = part.index
        try:
            existing = blocks[position]
        except IndexError:
            # first delta for this index: materialise it as a new content block
            fresh = construct_type(
                # mypy doesn't allow Content for some reason
                type_=cast(Any, MessageContent),
                value=model_dump(part, exclude_unset=True, warnings=False),
            )
            blocks.insert(position, cast(MessageContent, fresh))
            created.append(part)
            continue

        # the block already exists: merge the delta into its dumped form
        existing_fields = cast(
            "dict[object, object]",
            model_dump(existing, exclude_unset=True, warnings=False),
        )
        part_fields = cast(
            "dict[object, object]",
            model_dump(part, exclude_unset=True, warnings=False),
        )
        combined = accumulate_delta(existing_fields, part_fields)
        blocks[position] = cast(
            MessageContent,
            construct_type(
                # mypy doesn't allow Content for some reason
                type_=cast(Any, MessageContent),
                value=combined,
            ),
        )

    return current_message_snapshot, created
def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
    """Merge `delta` into `acc` in place and return `acc`.

    Per key of `delta`:

    - missing or `None` in `acc`: the delta value is taken as-is;
    - `index` / `type` keys: overwritten rather than accumulated;
    - str + str, number + number: added together;
    - dict + dict: merged recursively;
    - list + list: if the accumulated list holds only str/int/float items the
      delta items are appended, otherwise each delta entry is merged into the
      accumulated entry at its `index`;
    - any other combination: the accumulated value is kept unchanged.

    Raises:
        TypeError: for a non-dict list delta entry, a non-integer `index`, or
            a non-dict accumulated entry at the targeted index.
        RuntimeError: for a list delta entry that has no `index` key.
    """
    for key, incoming in delta.items():
        current = acc.get(key)

        # the `index` property is used in arrays of objects so it should
        # not be accumulated like other values e.g.
        # [{'foo': 'bar', 'index': 0}]
        #
        # the same applies to `type` properties as they're used for
        # discriminated unions
        #
        # absent / null accumulated values likewise have nothing to merge with
        if current is None or key == "index" or key == "type":
            acc[key] = incoming
            continue

        merged: object
        if isinstance(current, str) and isinstance(incoming, str):
            merged = current + incoming
        elif isinstance(current, (int, float)) and isinstance(incoming, (int, float)):
            merged = current + incoming
        elif is_dict(current) and is_dict(incoming):
            merged = accumulate_delta(current, incoming)
        elif is_list(current) and is_list(incoming):
            # for lists of non-dictionary items we'll only ever get new entries
            # in the array, existing entries will never be changed
            if all(isinstance(item, (str, int, float)) for item in current):
                current.extend(incoming)
                continue

            _accumulate_indexed_entries(current, incoming)
            merged = current
        else:
            # incompatible shapes: leave the accumulated value as it is
            merged = current

        acc[key] = merged

    return acc


def _accumulate_indexed_entries(acc_entries: list[object], delta_entries: list[object]) -> None:
    """Merge dict entries of `delta_entries` into `acc_entries` by their `index` key, in place."""
    for entry in delta_entries:
        if not is_dict(entry):
            raise TypeError(f"Unexpected list delta entry is not a dictionary: {entry}")

        try:
            index = entry["index"]
        except KeyError as exc:
            raise RuntimeError(f"Expected list delta entry to have an `index` key; {entry}") from exc

        if not isinstance(index, int):
            raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")

        try:
            target = acc_entries[index]
        except IndexError:
            # nothing at this position yet: the delta entry becomes the entry
            acc_entries.insert(index, entry)
            continue

        if not is_dict(target):
            raise TypeError("not handled yet")

        acc_entries[index] = accumulate_delta(target, entry)
|