| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- from __future__ import annotations
- import json
- from typing import TYPE_CHECKING, Any, List, Iterable, cast
- from typing_extensions import TypeVar, assert_never
- import pydantic
- from .._tools import ResponsesPydanticFunctionTool
- from ..._types import Omit
- from ..._utils import is_given
- from ..._compat import PYDANTIC_V1, model_parse_json
- from ..._models import construct_type_unchecked
- from .._pydantic import is_basemodel_type, is_dataclass_like_type
- from ._completions import solve_response_format_t, type_to_response_format_param
- from ...types.responses import (
- Response,
- ToolParam,
- ParsedContent,
- ParsedResponse,
- FunctionToolParam,
- ParsedResponseOutputItem,
- ParsedResponseOutputText,
- ResponseFunctionToolCall,
- ParsedResponseOutputMessage,
- ResponseFormatTextConfigParam,
- ParsedResponseFunctionToolCall,
- )
- from ...types.chat.completion_create_params import ResponseFormat
- TextFormatT = TypeVar(
- "TextFormatT",
- # if it isn't given then we don't do any parsing
- default=None,
- )
- def type_to_text_format_param(type_: type) -> ResponseFormatTextConfigParam:
- response_format_dict = type_to_response_format_param(type_)
- assert is_given(response_format_dict)
- response_format_dict = cast(ResponseFormat, response_format_dict) # pyright: ignore[reportUnnecessaryCast]
- assert response_format_dict["type"] == "json_schema"
- assert "schema" in response_format_dict["json_schema"]
- return {
- "type": "json_schema",
- "strict": True,
- "name": response_format_dict["json_schema"]["name"],
- "schema": response_format_dict["json_schema"]["schema"],
- }
- def parse_response(
- *,
- text_format: type[TextFormatT] | Omit,
- input_tools: Iterable[ToolParam] | Omit | None,
- response: Response | ParsedResponse[object],
- ) -> ParsedResponse[TextFormatT]:
- solved_t = solve_response_format_t(text_format)
- output_list: List[ParsedResponseOutputItem[TextFormatT]] = []
- for output in response.output:
- if output.type == "message":
- content_list: List[ParsedContent[TextFormatT]] = []
- for item in output.content:
- if item.type != "output_text":
- content_list.append(item)
- continue
- content_list.append(
- construct_type_unchecked(
- type_=cast(Any, ParsedResponseOutputText)[solved_t],
- value={
- **item.to_dict(),
- "parsed": parse_text(item.text, text_format=text_format),
- },
- )
- )
- output_list.append(
- construct_type_unchecked(
- type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
- value={
- **output.to_dict(),
- "content": content_list,
- },
- )
- )
- elif output.type == "function_call":
- output_list.append(
- construct_type_unchecked(
- type_=ParsedResponseFunctionToolCall,
- value={
- **output.to_dict(),
- "parsed_arguments": parse_function_tool_arguments(
- input_tools=input_tools, function_call=output
- ),
- },
- )
- )
- elif (
- output.type == "computer_call"
- or output.type == "file_search_call"
- or output.type == "web_search_call"
- or output.type == "reasoning"
- or output.type == "compaction"
- or output.type == "mcp_call"
- or output.type == "mcp_approval_request"
- or output.type == "image_generation_call"
- or output.type == "code_interpreter_call"
- or output.type == "local_shell_call"
- or output.type == "shell_call"
- or output.type == "shell_call_output"
- or output.type == "apply_patch_call"
- or output.type == "apply_patch_call_output"
- or output.type == "mcp_list_tools"
- or output.type == "exec"
- or output.type == "custom_tool_call"
- ):
- output_list.append(output)
- elif TYPE_CHECKING: # type: ignore
- assert_never(output)
- else:
- output_list.append(output)
- return cast(
- ParsedResponse[TextFormatT],
- construct_type_unchecked(
- type_=cast(Any, ParsedResponse)[solved_t],
- value={
- **response.to_dict(),
- "output": output_list,
- },
- ),
- )
- def parse_text(text: str, text_format: type[TextFormatT] | Omit) -> TextFormatT | None:
- if not is_given(text_format):
- return None
- if is_basemodel_type(text_format):
- return cast(TextFormatT, model_parse_json(text_format, text))
- if is_dataclass_like_type(text_format):
- if PYDANTIC_V1:
- raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {text_format}")
- return pydantic.TypeAdapter(text_format).validate_json(text)
- raise TypeError(f"Unable to automatically parse response format type {text_format}")
- def get_input_tool_by_name(*, input_tools: Iterable[ToolParam], name: str) -> FunctionToolParam | None:
- for tool in input_tools:
- if tool["type"] == "function" and tool.get("name") == name:
- return tool
- return None
- def parse_function_tool_arguments(
- *,
- input_tools: Iterable[ToolParam] | Omit | None,
- function_call: ParsedResponseFunctionToolCall | ResponseFunctionToolCall,
- ) -> object:
- if input_tools is None or not is_given(input_tools):
- return None
- input_tool = get_input_tool_by_name(input_tools=input_tools, name=function_call.name)
- if not input_tool:
- return None
- tool = cast(object, input_tool)
- if isinstance(tool, ResponsesPydanticFunctionTool):
- return model_parse_json(tool.model, function_call.arguments)
- if not input_tool.get("strict"):
- return None
- return json.loads(function_call.arguments)
|