_responses.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. from __future__ import annotations
  2. import json
  3. from typing import TYPE_CHECKING, Any, List, Iterable, cast
  4. from typing_extensions import TypeVar, assert_never
  5. import pydantic
  6. from .._tools import ResponsesPydanticFunctionTool
  7. from ..._types import Omit
  8. from ..._utils import is_given
  9. from ..._compat import PYDANTIC_V1, model_parse_json
  10. from ..._models import construct_type_unchecked
  11. from .._pydantic import is_basemodel_type, is_dataclass_like_type
  12. from ._completions import solve_response_format_t, type_to_response_format_param
  13. from ...types.responses import (
  14. Response,
  15. ToolParam,
  16. ParsedContent,
  17. ParsedResponse,
  18. FunctionToolParam,
  19. ParsedResponseOutputItem,
  20. ParsedResponseOutputText,
  21. ResponseFunctionToolCall,
  22. ParsedResponseOutputMessage,
  23. ResponseFormatTextConfigParam,
  24. ParsedResponseFunctionToolCall,
  25. )
  26. from ...types.chat.completion_create_params import ResponseFormat
  27. TextFormatT = TypeVar(
  28. "TextFormatT",
  29. # if it isn't given then we don't do any parsing
  30. default=None,
  31. )
  32. def type_to_text_format_param(type_: type) -> ResponseFormatTextConfigParam:
  33. response_format_dict = type_to_response_format_param(type_)
  34. assert is_given(response_format_dict)
  35. response_format_dict = cast(ResponseFormat, response_format_dict) # pyright: ignore[reportUnnecessaryCast]
  36. assert response_format_dict["type"] == "json_schema"
  37. assert "schema" in response_format_dict["json_schema"]
  38. return {
  39. "type": "json_schema",
  40. "strict": True,
  41. "name": response_format_dict["json_schema"]["name"],
  42. "schema": response_format_dict["json_schema"]["schema"],
  43. }
  44. def parse_response(
  45. *,
  46. text_format: type[TextFormatT] | Omit,
  47. input_tools: Iterable[ToolParam] | Omit | None,
  48. response: Response | ParsedResponse[object],
  49. ) -> ParsedResponse[TextFormatT]:
  50. solved_t = solve_response_format_t(text_format)
  51. output_list: List[ParsedResponseOutputItem[TextFormatT]] = []
  52. for output in response.output:
  53. if output.type == "message":
  54. content_list: List[ParsedContent[TextFormatT]] = []
  55. for item in output.content:
  56. if item.type != "output_text":
  57. content_list.append(item)
  58. continue
  59. content_list.append(
  60. construct_type_unchecked(
  61. type_=cast(Any, ParsedResponseOutputText)[solved_t],
  62. value={
  63. **item.to_dict(),
  64. "parsed": parse_text(item.text, text_format=text_format),
  65. },
  66. )
  67. )
  68. output_list.append(
  69. construct_type_unchecked(
  70. type_=cast(Any, ParsedResponseOutputMessage)[solved_t],
  71. value={
  72. **output.to_dict(),
  73. "content": content_list,
  74. },
  75. )
  76. )
  77. elif output.type == "function_call":
  78. output_list.append(
  79. construct_type_unchecked(
  80. type_=ParsedResponseFunctionToolCall,
  81. value={
  82. **output.to_dict(),
  83. "parsed_arguments": parse_function_tool_arguments(
  84. input_tools=input_tools, function_call=output
  85. ),
  86. },
  87. )
  88. )
  89. elif (
  90. output.type == "computer_call"
  91. or output.type == "file_search_call"
  92. or output.type == "web_search_call"
  93. or output.type == "reasoning"
  94. or output.type == "compaction"
  95. or output.type == "mcp_call"
  96. or output.type == "mcp_approval_request"
  97. or output.type == "image_generation_call"
  98. or output.type == "code_interpreter_call"
  99. or output.type == "local_shell_call"
  100. or output.type == "shell_call"
  101. or output.type == "shell_call_output"
  102. or output.type == "apply_patch_call"
  103. or output.type == "apply_patch_call_output"
  104. or output.type == "mcp_list_tools"
  105. or output.type == "exec"
  106. or output.type == "custom_tool_call"
  107. ):
  108. output_list.append(output)
  109. elif TYPE_CHECKING: # type: ignore
  110. assert_never(output)
  111. else:
  112. output_list.append(output)
  113. return cast(
  114. ParsedResponse[TextFormatT],
  115. construct_type_unchecked(
  116. type_=cast(Any, ParsedResponse)[solved_t],
  117. value={
  118. **response.to_dict(),
  119. "output": output_list,
  120. },
  121. ),
  122. )
  123. def parse_text(text: str, text_format: type[TextFormatT] | Omit) -> TextFormatT | None:
  124. if not is_given(text_format):
  125. return None
  126. if is_basemodel_type(text_format):
  127. return cast(TextFormatT, model_parse_json(text_format, text))
  128. if is_dataclass_like_type(text_format):
  129. if PYDANTIC_V1:
  130. raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {text_format}")
  131. return pydantic.TypeAdapter(text_format).validate_json(text)
  132. raise TypeError(f"Unable to automatically parse response format type {text_format}")
  133. def get_input_tool_by_name(*, input_tools: Iterable[ToolParam], name: str) -> FunctionToolParam | None:
  134. for tool in input_tools:
  135. if tool["type"] == "function" and tool.get("name") == name:
  136. return tool
  137. return None
  138. def parse_function_tool_arguments(
  139. *,
  140. input_tools: Iterable[ToolParam] | Omit | None,
  141. function_call: ParsedResponseFunctionToolCall | ResponseFunctionToolCall,
  142. ) -> object:
  143. if input_tools is None or not is_given(input_tools):
  144. return None
  145. input_tool = get_input_tool_by_name(input_tools=input_tools, name=function_call.name)
  146. if not input_tool:
  147. return None
  148. tool = cast(object, input_tool)
  149. if isinstance(tool, ResponsesPydanticFunctionTool):
  150. return model_parse_json(tool.model, function_call.arguments)
  151. if not input_tool.get("strict"):
  152. return None
  153. return json.loads(function_call.arguments)