_assistants.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038
  1. from __future__ import annotations
  2. import asyncio
  3. from types import TracebackType
  4. from typing import TYPE_CHECKING, Any, Generic, TypeVar, Callable, Iterable, Iterator, cast
  5. from typing_extensions import Awaitable, AsyncIterable, AsyncIterator, assert_never
  6. import httpx
  7. from ..._utils import is_dict, is_list, consume_sync_iterator, consume_async_iterator
  8. from ..._compat import model_dump
  9. from ..._models import construct_type
  10. from ..._streaming import Stream, AsyncStream
  11. from ...types.beta import AssistantStreamEvent
  12. from ...types.beta.threads import (
  13. Run,
  14. Text,
  15. Message,
  16. ImageFile,
  17. TextDelta,
  18. MessageDelta,
  19. MessageContent,
  20. MessageContentDelta,
  21. )
  22. from ...types.beta.threads.runs import RunStep, ToolCall, RunStepDelta, ToolCallDelta
  23. class AssistantEventHandler:
  24. text_deltas: Iterable[str]
  25. """Iterator over just the text deltas in the stream.
  26. This corresponds to the `thread.message.delta` event
  27. in the API.
  28. ```py
  29. for text in stream.text_deltas:
  30. print(text, end="", flush=True)
  31. print()
  32. ```
  33. """
  34. def __init__(self) -> None:
  35. self._current_event: AssistantStreamEvent | None = None
  36. self._current_message_content_index: int | None = None
  37. self._current_message_content: MessageContent | None = None
  38. self._current_tool_call_index: int | None = None
  39. self._current_tool_call: ToolCall | None = None
  40. self.__current_run_step_id: str | None = None
  41. self.__current_run: Run | None = None
  42. self.__run_step_snapshots: dict[str, RunStep] = {}
  43. self.__message_snapshots: dict[str, Message] = {}
  44. self.__current_message_snapshot: Message | None = None
  45. self.text_deltas = self.__text_deltas__()
  46. self._iterator = self.__stream__()
  47. self.__stream: Stream[AssistantStreamEvent] | None = None
  48. def _init(self, stream: Stream[AssistantStreamEvent]) -> None:
  49. if self.__stream:
  50. raise RuntimeError(
  51. "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
  52. )
  53. self.__stream = stream
  54. def __next__(self) -> AssistantStreamEvent:
  55. return self._iterator.__next__()
  56. def __iter__(self) -> Iterator[AssistantStreamEvent]:
  57. for item in self._iterator:
  58. yield item
  59. @property
  60. def current_event(self) -> AssistantStreamEvent | None:
  61. return self._current_event
  62. @property
  63. def current_run(self) -> Run | None:
  64. return self.__current_run
  65. @property
  66. def current_run_step_snapshot(self) -> RunStep | None:
  67. if not self.__current_run_step_id:
  68. return None
  69. return self.__run_step_snapshots[self.__current_run_step_id]
  70. @property
  71. def current_message_snapshot(self) -> Message | None:
  72. return self.__current_message_snapshot
  73. def close(self) -> None:
  74. """
  75. Close the response and release the connection.
  76. Automatically called when the context manager exits.
  77. """
  78. if self.__stream:
  79. self.__stream.close()
  80. def until_done(self) -> None:
  81. """Waits until the stream has been consumed"""
  82. consume_sync_iterator(self)
  83. def get_final_run(self) -> Run:
  84. """Wait for the stream to finish and returns the completed Run object"""
  85. self.until_done()
  86. if not self.__current_run:
  87. raise RuntimeError("No final run object found")
  88. return self.__current_run
  89. def get_final_run_steps(self) -> list[RunStep]:
  90. """Wait for the stream to finish and returns the steps taken in this run"""
  91. self.until_done()
  92. if not self.__run_step_snapshots:
  93. raise RuntimeError("No run steps found")
  94. return [step for step in self.__run_step_snapshots.values()]
  95. def get_final_messages(self) -> list[Message]:
  96. """Wait for the stream to finish and returns the messages emitted in this run"""
  97. self.until_done()
  98. if not self.__message_snapshots:
  99. raise RuntimeError("No messages found")
  100. return [message for message in self.__message_snapshots.values()]
  101. def __text_deltas__(self) -> Iterator[str]:
  102. for event in self:
  103. if event.event != "thread.message.delta":
  104. continue
  105. for content_delta in event.data.delta.content or []:
  106. if content_delta.type == "text" and content_delta.text and content_delta.text.value:
  107. yield content_delta.text.value
  108. # event handlers
  109. def on_end(self) -> None:
  110. """Fires when the stream has finished.
  111. This happens if the stream is read to completion
  112. or if an exception occurs during iteration.
  113. """
  114. def on_event(self, event: AssistantStreamEvent) -> None:
  115. """Callback that is fired for every Server-Sent-Event"""
  116. def on_run_step_created(self, run_step: RunStep) -> None:
  117. """Callback that is fired when a run step is created"""
  118. def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
  119. """Callback that is fired whenever a run step delta is returned from the API
  120. The first argument is just the delta as sent by the API and the second argument
  121. is the accumulated snapshot of the run step. For example, a tool calls event may
  122. look like this:
  123. # delta
  124. tool_calls=[
  125. RunStepDeltaToolCallsCodeInterpreter(
  126. index=0,
  127. type='code_interpreter',
  128. id=None,
  129. code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
  130. )
  131. ]
  132. # snapshot
  133. tool_calls=[
  134. CodeToolCall(
  135. id='call_wKayJlcYV12NiadiZuJXxcfx',
  136. code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
  137. type='code_interpreter',
  138. index=0
  139. )
  140. ],
  141. """
  142. def on_run_step_done(self, run_step: RunStep) -> None:
  143. """Callback that is fired when a run step is completed"""
  144. def on_tool_call_created(self, tool_call: ToolCall) -> None:
  145. """Callback that is fired when a tool call is created"""
  146. def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
  147. """Callback that is fired when a tool call delta is encountered"""
  148. def on_tool_call_done(self, tool_call: ToolCall) -> None:
  149. """Callback that is fired when a tool call delta is encountered"""
  150. def on_exception(self, exception: Exception) -> None:
  151. """Fired whenever an exception happens during streaming"""
  152. def on_timeout(self) -> None:
  153. """Fires if the request times out"""
  154. def on_message_created(self, message: Message) -> None:
  155. """Callback that is fired when a message is created"""
  156. def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
  157. """Callback that is fired whenever a message delta is returned from the API
  158. The first argument is just the delta as sent by the API and the second argument
  159. is the accumulated snapshot of the message. For example, a text content event may
  160. look like this:
  161. # delta
  162. MessageDeltaText(
  163. index=0,
  164. type='text',
  165. text=Text(
  166. value=' Jane'
  167. ),
  168. )
  169. # snapshot
  170. MessageContentText(
  171. index=0,
  172. type='text',
  173. text=Text(
  174. value='Certainly, Jane'
  175. ),
  176. )
  177. """
  178. def on_message_done(self, message: Message) -> None:
  179. """Callback that is fired when a message is completed"""
  180. def on_text_created(self, text: Text) -> None:
  181. """Callback that is fired when a text content block is created"""
  182. def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
  183. """Callback that is fired whenever a text content delta is returned
  184. by the API.
  185. The first argument is just the delta as sent by the API and the second argument
  186. is the accumulated snapshot of the text. For example:
  187. on_text_delta(TextDelta(value="The"), Text(value="The")),
  188. on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
  189. on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
  190. on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
  191. on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equation")),
  192. """
  193. def on_text_done(self, text: Text) -> None:
  194. """Callback that is fired when a text content block is finished"""
  195. def on_image_file_done(self, image_file: ImageFile) -> None:
  196. """Callback that is fired when an image file block is finished"""
  197. def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
  198. self._current_event = event
  199. self.on_event(event)
  200. self.__current_message_snapshot, new_content = accumulate_event(
  201. event=event,
  202. current_message_snapshot=self.__current_message_snapshot,
  203. )
  204. if self.__current_message_snapshot is not None:
  205. self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot
  206. accumulate_run_step(
  207. event=event,
  208. run_step_snapshots=self.__run_step_snapshots,
  209. )
  210. for content_delta in new_content:
  211. assert self.__current_message_snapshot is not None
  212. block = self.__current_message_snapshot.content[content_delta.index]
  213. if block.type == "text":
  214. self.on_text_created(block.text)
  215. if (
  216. event.event == "thread.run.completed"
  217. or event.event == "thread.run.cancelled"
  218. or event.event == "thread.run.expired"
  219. or event.event == "thread.run.failed"
  220. or event.event == "thread.run.requires_action"
  221. or event.event == "thread.run.incomplete"
  222. ):
  223. self.__current_run = event.data
  224. if self._current_tool_call:
  225. self.on_tool_call_done(self._current_tool_call)
  226. elif (
  227. event.event == "thread.run.created"
  228. or event.event == "thread.run.in_progress"
  229. or event.event == "thread.run.cancelling"
  230. or event.event == "thread.run.queued"
  231. ):
  232. self.__current_run = event.data
  233. elif event.event == "thread.message.created":
  234. self.on_message_created(event.data)
  235. elif event.event == "thread.message.delta":
  236. snapshot = self.__current_message_snapshot
  237. assert snapshot is not None
  238. message_delta = event.data.delta
  239. if message_delta.content is not None:
  240. for content_delta in message_delta.content:
  241. if content_delta.type == "text" and content_delta.text:
  242. snapshot_content = snapshot.content[content_delta.index]
  243. assert snapshot_content.type == "text"
  244. self.on_text_delta(content_delta.text, snapshot_content.text)
  245. # If the delta is for a new message content:
  246. # - emit on_text_done/on_image_file_done for the previous message content
  247. # - emit on_text_created/on_image_created for the new message content
  248. if content_delta.index != self._current_message_content_index:
  249. if self._current_message_content is not None:
  250. if self._current_message_content.type == "text":
  251. self.on_text_done(self._current_message_content.text)
  252. elif self._current_message_content.type == "image_file":
  253. self.on_image_file_done(self._current_message_content.image_file)
  254. self._current_message_content_index = content_delta.index
  255. self._current_message_content = snapshot.content[content_delta.index]
  256. # Update the current_message_content (delta event is correctly emitted already)
  257. self._current_message_content = snapshot.content[content_delta.index]
  258. self.on_message_delta(event.data.delta, snapshot)
  259. elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
  260. self.__current_message_snapshot = event.data
  261. self.__message_snapshots[event.data.id] = event.data
  262. if self._current_message_content_index is not None:
  263. content = event.data.content[self._current_message_content_index]
  264. if content.type == "text":
  265. self.on_text_done(content.text)
  266. elif content.type == "image_file":
  267. self.on_image_file_done(content.image_file)
  268. self.on_message_done(event.data)
  269. elif event.event == "thread.run.step.created":
  270. self.__current_run_step_id = event.data.id
  271. self.on_run_step_created(event.data)
  272. elif event.event == "thread.run.step.in_progress":
  273. self.__current_run_step_id = event.data.id
  274. elif event.event == "thread.run.step.delta":
  275. step_snapshot = self.__run_step_snapshots[event.data.id]
  276. run_step_delta = event.data.delta
  277. if (
  278. run_step_delta.step_details
  279. and run_step_delta.step_details.type == "tool_calls"
  280. and run_step_delta.step_details.tool_calls is not None
  281. ):
  282. assert step_snapshot.step_details.type == "tool_calls"
  283. for tool_call_delta in run_step_delta.step_details.tool_calls:
  284. if tool_call_delta.index == self._current_tool_call_index:
  285. self.on_tool_call_delta(
  286. tool_call_delta,
  287. step_snapshot.step_details.tool_calls[tool_call_delta.index],
  288. )
  289. # If the delta is for a new tool call:
  290. # - emit on_tool_call_done for the previous tool_call
  291. # - emit on_tool_call_created for the new tool_call
  292. if tool_call_delta.index != self._current_tool_call_index:
  293. if self._current_tool_call is not None:
  294. self.on_tool_call_done(self._current_tool_call)
  295. self._current_tool_call_index = tool_call_delta.index
  296. self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
  297. self.on_tool_call_created(self._current_tool_call)
  298. # Update the current_tool_call (delta event is correctly emitted already)
  299. self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
  300. self.on_run_step_delta(
  301. event.data.delta,
  302. step_snapshot,
  303. )
  304. elif (
  305. event.event == "thread.run.step.completed"
  306. or event.event == "thread.run.step.cancelled"
  307. or event.event == "thread.run.step.expired"
  308. or event.event == "thread.run.step.failed"
  309. ):
  310. if self._current_tool_call:
  311. self.on_tool_call_done(self._current_tool_call)
  312. self.on_run_step_done(event.data)
  313. self.__current_run_step_id = None
  314. elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
  315. # currently no special handling
  316. ...
  317. else:
  318. # we only want to error at build-time
  319. if TYPE_CHECKING: # type: ignore[unreachable]
  320. assert_never(event)
  321. self._current_event = None
  322. def __stream__(self) -> Iterator[AssistantStreamEvent]:
  323. stream = self.__stream
  324. if not stream:
  325. raise RuntimeError("Stream has not been started yet")
  326. try:
  327. for event in stream:
  328. self._emit_sse_event(event)
  329. yield event
  330. except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
  331. self.on_timeout()
  332. self.on_exception(exc)
  333. raise
  334. except Exception as exc:
  335. self.on_exception(exc)
  336. raise
  337. finally:
  338. self.on_end()
# Type variable bound to AssistantEventHandler; AssistantStreamManager uses it
# so that `__enter__` returns the caller's own handler subclass type.
AssistantEventHandlerT = TypeVar("AssistantEventHandlerT", bound=AssistantEventHandler)
  340. class AssistantStreamManager(Generic[AssistantEventHandlerT]):
  341. """Wrapper over AssistantStreamEventHandler that is returned by `.stream()`
  342. so that a context manager can be used.
  343. ```py
  344. with client.threads.create_and_run_stream(...) as stream:
  345. for event in stream:
  346. ...
  347. ```
  348. """
  349. def __init__(
  350. self,
  351. api_request: Callable[[], Stream[AssistantStreamEvent]],
  352. *,
  353. event_handler: AssistantEventHandlerT,
  354. ) -> None:
  355. self.__stream: Stream[AssistantStreamEvent] | None = None
  356. self.__event_handler = event_handler
  357. self.__api_request = api_request
  358. def __enter__(self) -> AssistantEventHandlerT:
  359. self.__stream = self.__api_request()
  360. self.__event_handler._init(self.__stream)
  361. return self.__event_handler
  362. def __exit__(
  363. self,
  364. exc_type: type[BaseException] | None,
  365. exc: BaseException | None,
  366. exc_tb: TracebackType | None,
  367. ) -> None:
  368. if self.__stream is not None:
  369. self.__stream.close()
  370. class AsyncAssistantEventHandler:
  371. text_deltas: AsyncIterable[str]
  372. """Iterator over just the text deltas in the stream.
  373. This corresponds to the `thread.message.delta` event
  374. in the API.
  375. ```py
  376. async for text in stream.text_deltas:
  377. print(text, end="", flush=True)
  378. print()
  379. ```
  380. """
  381. def __init__(self) -> None:
  382. self._current_event: AssistantStreamEvent | None = None
  383. self._current_message_content_index: int | None = None
  384. self._current_message_content: MessageContent | None = None
  385. self._current_tool_call_index: int | None = None
  386. self._current_tool_call: ToolCall | None = None
  387. self.__current_run_step_id: str | None = None
  388. self.__current_run: Run | None = None
  389. self.__run_step_snapshots: dict[str, RunStep] = {}
  390. self.__message_snapshots: dict[str, Message] = {}
  391. self.__current_message_snapshot: Message | None = None
  392. self.text_deltas = self.__text_deltas__()
  393. self._iterator = self.__stream__()
  394. self.__stream: AsyncStream[AssistantStreamEvent] | None = None
  395. def _init(self, stream: AsyncStream[AssistantStreamEvent]) -> None:
  396. if self.__stream:
  397. raise RuntimeError(
  398. "A single event handler cannot be shared between multiple streams; You will need to construct a new event handler instance"
  399. )
  400. self.__stream = stream
  401. async def __anext__(self) -> AssistantStreamEvent:
  402. return await self._iterator.__anext__()
  403. async def __aiter__(self) -> AsyncIterator[AssistantStreamEvent]:
  404. async for item in self._iterator:
  405. yield item
  406. async def close(self) -> None:
  407. """
  408. Close the response and release the connection.
  409. Automatically called when the context manager exits.
  410. """
  411. if self.__stream:
  412. await self.__stream.close()
  413. @property
  414. def current_event(self) -> AssistantStreamEvent | None:
  415. return self._current_event
  416. @property
  417. def current_run(self) -> Run | None:
  418. return self.__current_run
  419. @property
  420. def current_run_step_snapshot(self) -> RunStep | None:
  421. if not self.__current_run_step_id:
  422. return None
  423. return self.__run_step_snapshots[self.__current_run_step_id]
  424. @property
  425. def current_message_snapshot(self) -> Message | None:
  426. return self.__current_message_snapshot
  427. async def until_done(self) -> None:
  428. """Waits until the stream has been consumed"""
  429. await consume_async_iterator(self)
  430. async def get_final_run(self) -> Run:
  431. """Wait for the stream to finish and returns the completed Run object"""
  432. await self.until_done()
  433. if not self.__current_run:
  434. raise RuntimeError("No final run object found")
  435. return self.__current_run
  436. async def get_final_run_steps(self) -> list[RunStep]:
  437. """Wait for the stream to finish and returns the steps taken in this run"""
  438. await self.until_done()
  439. if not self.__run_step_snapshots:
  440. raise RuntimeError("No run steps found")
  441. return [step for step in self.__run_step_snapshots.values()]
  442. async def get_final_messages(self) -> list[Message]:
  443. """Wait for the stream to finish and returns the messages emitted in this run"""
  444. await self.until_done()
  445. if not self.__message_snapshots:
  446. raise RuntimeError("No messages found")
  447. return [message for message in self.__message_snapshots.values()]
  448. async def __text_deltas__(self) -> AsyncIterator[str]:
  449. async for event in self:
  450. if event.event != "thread.message.delta":
  451. continue
  452. for content_delta in event.data.delta.content or []:
  453. if content_delta.type == "text" and content_delta.text and content_delta.text.value:
  454. yield content_delta.text.value
  455. # event handlers
  456. async def on_end(self) -> None:
  457. """Fires when the stream has finished.
  458. This happens if the stream is read to completion
  459. or if an exception occurs during iteration.
  460. """
  461. async def on_event(self, event: AssistantStreamEvent) -> None:
  462. """Callback that is fired for every Server-Sent-Event"""
  463. async def on_run_step_created(self, run_step: RunStep) -> None:
  464. """Callback that is fired when a run step is created"""
  465. async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None:
  466. """Callback that is fired whenever a run step delta is returned from the API
  467. The first argument is just the delta as sent by the API and the second argument
  468. is the accumulated snapshot of the run step. For example, a tool calls event may
  469. look like this:
  470. # delta
  471. tool_calls=[
  472. RunStepDeltaToolCallsCodeInterpreter(
  473. index=0,
  474. type='code_interpreter',
  475. id=None,
  476. code_interpreter=CodeInterpreter(input=' sympy', outputs=None)
  477. )
  478. ]
  479. # snapshot
  480. tool_calls=[
  481. CodeToolCall(
  482. id='call_wKayJlcYV12NiadiZuJXxcfx',
  483. code_interpreter=CodeInterpreter(input='from sympy', outputs=[]),
  484. type='code_interpreter',
  485. index=0
  486. )
  487. ],
  488. """
  489. async def on_run_step_done(self, run_step: RunStep) -> None:
  490. """Callback that is fired when a run step is completed"""
  491. async def on_tool_call_created(self, tool_call: ToolCall) -> None:
  492. """Callback that is fired when a tool call is created"""
  493. async def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
  494. """Callback that is fired when a tool call delta is encountered"""
  495. async def on_tool_call_done(self, tool_call: ToolCall) -> None:
  496. """Callback that is fired when a tool call delta is encountered"""
  497. async def on_exception(self, exception: Exception) -> None:
  498. """Fired whenever an exception happens during streaming"""
  499. async def on_timeout(self) -> None:
  500. """Fires if the request times out"""
  501. async def on_message_created(self, message: Message) -> None:
  502. """Callback that is fired when a message is created"""
  503. async def on_message_delta(self, delta: MessageDelta, snapshot: Message) -> None:
  504. """Callback that is fired whenever a message delta is returned from the API
  505. The first argument is just the delta as sent by the API and the second argument
  506. is the accumulated snapshot of the message. For example, a text content event may
  507. look like this:
  508. # delta
  509. MessageDeltaText(
  510. index=0,
  511. type='text',
  512. text=Text(
  513. value=' Jane'
  514. ),
  515. )
  516. # snapshot
  517. MessageContentText(
  518. index=0,
  519. type='text',
  520. text=Text(
  521. value='Certainly, Jane'
  522. ),
  523. )
  524. """
  525. async def on_message_done(self, message: Message) -> None:
  526. """Callback that is fired when a message is completed"""
  527. async def on_text_created(self, text: Text) -> None:
  528. """Callback that is fired when a text content block is created"""
  529. async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
  530. """Callback that is fired whenever a text content delta is returned
  531. by the API.
  532. The first argument is just the delta as sent by the API and the second argument
  533. is the accumulated snapshot of the text. For example:
  534. on_text_delta(TextDelta(value="The"), Text(value="The")),
  535. on_text_delta(TextDelta(value=" solution"), Text(value="The solution")),
  536. on_text_delta(TextDelta(value=" to"), Text(value="The solution to")),
  537. on_text_delta(TextDelta(value=" the"), Text(value="The solution to the")),
  538. on_text_delta(TextDelta(value=" equation"), Text(value="The solution to the equivalent")),
  539. """
  540. async def on_text_done(self, text: Text) -> None:
  541. """Callback that is fired when a text content block is finished"""
  542. async def on_image_file_done(self, image_file: ImageFile) -> None:
  543. """Callback that is fired when an image file block is finished"""
  544. async def _emit_sse_event(self, event: AssistantStreamEvent) -> None:
  545. self._current_event = event
  546. await self.on_event(event)
  547. self.__current_message_snapshot, new_content = accumulate_event(
  548. event=event,
  549. current_message_snapshot=self.__current_message_snapshot,
  550. )
  551. if self.__current_message_snapshot is not None:
  552. self.__message_snapshots[self.__current_message_snapshot.id] = self.__current_message_snapshot
  553. accumulate_run_step(
  554. event=event,
  555. run_step_snapshots=self.__run_step_snapshots,
  556. )
  557. for content_delta in new_content:
  558. assert self.__current_message_snapshot is not None
  559. block = self.__current_message_snapshot.content[content_delta.index]
  560. if block.type == "text":
  561. await self.on_text_created(block.text)
  562. if (
  563. event.event == "thread.run.completed"
  564. or event.event == "thread.run.cancelled"
  565. or event.event == "thread.run.expired"
  566. or event.event == "thread.run.failed"
  567. or event.event == "thread.run.requires_action"
  568. or event.event == "thread.run.incomplete"
  569. ):
  570. self.__current_run = event.data
  571. if self._current_tool_call:
  572. await self.on_tool_call_done(self._current_tool_call)
  573. elif (
  574. event.event == "thread.run.created"
  575. or event.event == "thread.run.in_progress"
  576. or event.event == "thread.run.cancelling"
  577. or event.event == "thread.run.queued"
  578. ):
  579. self.__current_run = event.data
  580. elif event.event == "thread.message.created":
  581. await self.on_message_created(event.data)
  582. elif event.event == "thread.message.delta":
  583. snapshot = self.__current_message_snapshot
  584. assert snapshot is not None
  585. message_delta = event.data.delta
  586. if message_delta.content is not None:
  587. for content_delta in message_delta.content:
  588. if content_delta.type == "text" and content_delta.text:
  589. snapshot_content = snapshot.content[content_delta.index]
  590. assert snapshot_content.type == "text"
  591. await self.on_text_delta(content_delta.text, snapshot_content.text)
  592. # If the delta is for a new message content:
  593. # - emit on_text_done/on_image_file_done for the previous message content
  594. # - emit on_text_created/on_image_created for the new message content
  595. if content_delta.index != self._current_message_content_index:
  596. if self._current_message_content is not None:
  597. if self._current_message_content.type == "text":
  598. await self.on_text_done(self._current_message_content.text)
  599. elif self._current_message_content.type == "image_file":
  600. await self.on_image_file_done(self._current_message_content.image_file)
  601. self._current_message_content_index = content_delta.index
  602. self._current_message_content = snapshot.content[content_delta.index]
  603. # Update the current_message_content (delta event is correctly emitted already)
  604. self._current_message_content = snapshot.content[content_delta.index]
  605. await self.on_message_delta(event.data.delta, snapshot)
  606. elif event.event == "thread.message.completed" or event.event == "thread.message.incomplete":
  607. self.__current_message_snapshot = event.data
  608. self.__message_snapshots[event.data.id] = event.data
  609. if self._current_message_content_index is not None:
  610. content = event.data.content[self._current_message_content_index]
  611. if content.type == "text":
  612. await self.on_text_done(content.text)
  613. elif content.type == "image_file":
  614. await self.on_image_file_done(content.image_file)
  615. await self.on_message_done(event.data)
  616. elif event.event == "thread.run.step.created":
  617. self.__current_run_step_id = event.data.id
  618. await self.on_run_step_created(event.data)
  619. elif event.event == "thread.run.step.in_progress":
  620. self.__current_run_step_id = event.data.id
  621. elif event.event == "thread.run.step.delta":
  622. step_snapshot = self.__run_step_snapshots[event.data.id]
  623. run_step_delta = event.data.delta
  624. if (
  625. run_step_delta.step_details
  626. and run_step_delta.step_details.type == "tool_calls"
  627. and run_step_delta.step_details.tool_calls is not None
  628. ):
  629. assert step_snapshot.step_details.type == "tool_calls"
  630. for tool_call_delta in run_step_delta.step_details.tool_calls:
  631. if tool_call_delta.index == self._current_tool_call_index:
  632. await self.on_tool_call_delta(
  633. tool_call_delta,
  634. step_snapshot.step_details.tool_calls[tool_call_delta.index],
  635. )
  636. # If the delta is for a new tool call:
  637. # - emit on_tool_call_done for the previous tool_call
  638. # - emit on_tool_call_created for the new tool_call
  639. if tool_call_delta.index != self._current_tool_call_index:
  640. if self._current_tool_call is not None:
  641. await self.on_tool_call_done(self._current_tool_call)
  642. self._current_tool_call_index = tool_call_delta.index
  643. self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
  644. await self.on_tool_call_created(self._current_tool_call)
  645. # Update the current_tool_call (delta event is correctly emitted already)
  646. self._current_tool_call = step_snapshot.step_details.tool_calls[tool_call_delta.index]
  647. await self.on_run_step_delta(
  648. event.data.delta,
  649. step_snapshot,
  650. )
  651. elif (
  652. event.event == "thread.run.step.completed"
  653. or event.event == "thread.run.step.cancelled"
  654. or event.event == "thread.run.step.expired"
  655. or event.event == "thread.run.step.failed"
  656. ):
  657. if self._current_tool_call:
  658. await self.on_tool_call_done(self._current_tool_call)
  659. await self.on_run_step_done(event.data)
  660. self.__current_run_step_id = None
  661. elif event.event == "thread.created" or event.event == "thread.message.in_progress" or event.event == "error":
  662. # currently no special handling
  663. ...
  664. else:
  665. # we only want to error at build-time
  666. if TYPE_CHECKING: # type: ignore[unreachable]
  667. assert_never(event)
  668. self._current_event = None
  669. async def __stream__(self) -> AsyncIterator[AssistantStreamEvent]:
  670. stream = self.__stream
  671. if not stream:
  672. raise RuntimeError("Stream has not been started yet")
  673. try:
  674. async for event in stream:
  675. await self._emit_sse_event(event)
  676. yield event
  677. except (httpx.TimeoutException, asyncio.TimeoutError) as exc:
  678. await self.on_timeout()
  679. await self.on_exception(exc)
  680. raise
  681. except Exception as exc:
  682. await self.on_exception(exc)
  683. raise
  684. finally:
  685. await self.on_end()
  686. AsyncAssistantEventHandlerT = TypeVar("AsyncAssistantEventHandlerT", bound=AsyncAssistantEventHandler)
  687. class AsyncAssistantStreamManager(Generic[AsyncAssistantEventHandlerT]):
  688. """Wrapper over AsyncAssistantStreamEventHandler that is returned by `.stream()`
  689. so that an async context manager can be used without `await`ing the
  690. original client call.
  691. ```py
  692. async with client.threads.create_and_run_stream(...) as stream:
  693. async for event in stream:
  694. ...
  695. ```
  696. """
  697. def __init__(
  698. self,
  699. api_request: Awaitable[AsyncStream[AssistantStreamEvent]],
  700. *,
  701. event_handler: AsyncAssistantEventHandlerT,
  702. ) -> None:
  703. self.__stream: AsyncStream[AssistantStreamEvent] | None = None
  704. self.__event_handler = event_handler
  705. self.__api_request = api_request
  706. async def __aenter__(self) -> AsyncAssistantEventHandlerT:
  707. self.__stream = await self.__api_request
  708. self.__event_handler._init(self.__stream)
  709. return self.__event_handler
  710. async def __aexit__(
  711. self,
  712. exc_type: type[BaseException] | None,
  713. exc: BaseException | None,
  714. exc_tb: TracebackType | None,
  715. ) -> None:
  716. if self.__stream is not None:
  717. await self.__stream.close()
  718. def accumulate_run_step(
  719. *,
  720. event: AssistantStreamEvent,
  721. run_step_snapshots: dict[str, RunStep],
  722. ) -> None:
  723. if event.event == "thread.run.step.created":
  724. run_step_snapshots[event.data.id] = event.data
  725. return
  726. if event.event == "thread.run.step.delta":
  727. data = event.data
  728. snapshot = run_step_snapshots[data.id]
  729. if data.delta:
  730. merged = accumulate_delta(
  731. cast(
  732. "dict[object, object]",
  733. model_dump(snapshot, exclude_unset=True, warnings=False),
  734. ),
  735. cast(
  736. "dict[object, object]",
  737. model_dump(data.delta, exclude_unset=True, warnings=False),
  738. ),
  739. )
  740. run_step_snapshots[snapshot.id] = cast(RunStep, construct_type(type_=RunStep, value=merged))
  741. return None
  742. def accumulate_event(
  743. *,
  744. event: AssistantStreamEvent,
  745. current_message_snapshot: Message | None,
  746. ) -> tuple[Message | None, list[MessageContentDelta]]:
  747. """Returns a tuple of message snapshot and newly created text message deltas"""
  748. if event.event == "thread.message.created":
  749. return event.data, []
  750. new_content: list[MessageContentDelta] = []
  751. if event.event != "thread.message.delta":
  752. return current_message_snapshot, []
  753. if not current_message_snapshot:
  754. raise RuntimeError("Encountered a message delta with no previous snapshot")
  755. data = event.data
  756. if data.delta.content:
  757. for content_delta in data.delta.content:
  758. try:
  759. block = current_message_snapshot.content[content_delta.index]
  760. except IndexError:
  761. current_message_snapshot.content.insert(
  762. content_delta.index,
  763. cast(
  764. MessageContent,
  765. construct_type(
  766. # mypy doesn't allow Content for some reason
  767. type_=cast(Any, MessageContent),
  768. value=model_dump(content_delta, exclude_unset=True, warnings=False),
  769. ),
  770. ),
  771. )
  772. new_content.append(content_delta)
  773. else:
  774. merged = accumulate_delta(
  775. cast(
  776. "dict[object, object]",
  777. model_dump(block, exclude_unset=True, warnings=False),
  778. ),
  779. cast(
  780. "dict[object, object]",
  781. model_dump(content_delta, exclude_unset=True, warnings=False),
  782. ),
  783. )
  784. current_message_snapshot.content[content_delta.index] = cast(
  785. MessageContent,
  786. construct_type(
  787. # mypy doesn't allow Content for some reason
  788. type_=cast(Any, MessageContent),
  789. value=merged,
  790. ),
  791. )
  792. return current_message_snapshot, new_content
  793. def accumulate_delta(acc: dict[object, object], delta: dict[object, object]) -> dict[object, object]:
  794. for key, delta_value in delta.items():
  795. if key not in acc:
  796. acc[key] = delta_value
  797. continue
  798. acc_value = acc[key]
  799. if acc_value is None:
  800. acc[key] = delta_value
  801. continue
  802. # the `index` property is used in arrays of objects so it should
  803. # not be accumulated like other values e.g.
  804. # [{'foo': 'bar', 'index': 0}]
  805. #
  806. # the same applies to `type` properties as they're used for
  807. # discriminated unions
  808. if key == "index" or key == "type":
  809. acc[key] = delta_value
  810. continue
  811. if isinstance(acc_value, str) and isinstance(delta_value, str):
  812. acc_value += delta_value
  813. elif isinstance(acc_value, (int, float)) and isinstance(delta_value, (int, float)):
  814. acc_value += delta_value
  815. elif is_dict(acc_value) and is_dict(delta_value):
  816. acc_value = accumulate_delta(acc_value, delta_value)
  817. elif is_list(acc_value) and is_list(delta_value):
  818. # for lists of non-dictionary items we'll only ever get new entries
  819. # in the array, existing entries will never be changed
  820. if all(isinstance(x, (str, int, float)) for x in acc_value):
  821. acc_value.extend(delta_value)
  822. continue
  823. for delta_entry in delta_value:
  824. if not is_dict(delta_entry):
  825. raise TypeError(f"Unexpected list delta entry is not a dictionary: {delta_entry}")
  826. try:
  827. index = delta_entry["index"]
  828. except KeyError as exc:
  829. raise RuntimeError(f"Expected list delta entry to have an `index` key; {delta_entry}") from exc
  830. if not isinstance(index, int):
  831. raise TypeError(f"Unexpected, list delta entry `index` value is not an integer; {index}")
  832. try:
  833. acc_entry = acc_value[index]
  834. except IndexError:
  835. acc_value.insert(index, delta_entry)
  836. else:
  837. if not is_dict(acc_entry):
  838. raise TypeError("not handled yet")
  839. acc_value[index] = accumulate_delta(acc_entry, delta_entry)
  840. acc[key] = acc_value
  841. return acc