| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- from __future__ import annotations
- import json
- import logging
- from typing import TYPE_CHECKING, Any, Iterable, cast
- from typing_extensions import TypeVar, TypeGuard, assert_never
- import pydantic
- from .._tools import PydanticFunctionTool
- from ..._types import Omit, omit
- from ..._utils import is_dict, is_given
- from ..._compat import PYDANTIC_V1, model_parse_json
- from ..._models import construct_type_unchecked
- from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
- from ...types.chat import (
- ParsedChoice,
- ChatCompletion,
- ParsedFunction,
- ParsedChatCompletion,
- ChatCompletionMessage,
- ParsedFunctionToolCall,
- ParsedChatCompletionMessage,
- ChatCompletionToolUnionParam,
- ChatCompletionFunctionToolParam,
- completion_create_params,
- )
- from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
- from ...types.shared_params import FunctionDefinition
- from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
- from ...types.chat.chat_completion_message_function_tool_call import Function
# Type variable for the caller-supplied structured-output type used throughout
# this module's parsing helpers.
ResponseFormatT = TypeVar(
    "ResponseFormatT",
    # if it isn't given then we don't do any parsing
    default=None,
)

# Runtime stand-in returned (via a cast) by `solve_response_format_t` when there
# is no rich response format to parse into.
_default_response_format: None = None

# Module logger; used below to warn about custom tool calls that are skipped.
log: logging.Logger = logging.getLogger("openai.lib.parsing")
def is_strict_chat_completion_tool_param(
    tool: ChatCompletionToolUnionParam,
) -> TypeGuard[ChatCompletionFunctionToolParam]:
    """Tell whether `tool` is a function tool whose definition sets `strict` to `True`.

    Only the literal `True` counts; a missing `strict` key or any other value
    (including truthy non-bool values) yields `False`.
    """
    return tool["type"] == "function" and tool["function"].get("strict") is True
def select_strict_chat_completion_tools(
    tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
    """Filter `tools` down to the strict function tools.

    Returns `omit` when no tools were given; otherwise a new list holding, in
    their original order, the entries accepted by
    `is_strict_chat_completion_tool_param`.
    """
    if is_given(tools):
        strict_tools: list[ChatCompletionFunctionToolParam] = []
        for candidate in tools:
            if is_strict_chat_completion_tool_param(candidate):
                strict_tools.append(candidate)
        return strict_tools

    return omit
def validate_input_tools(
    tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
    """Ensure every given tool is a strict `function` tool and return the tools.

    Args:
        tools: The tools to check, or `omit` when none were supplied.

    Returns:
        `omit` if no tools were given. Otherwise the validated tools: the same
        object when `tools` is a list or tuple, else a list built from it.

    Raises:
        ValueError: If a tool's type is not `function`, or if a function tool
            does not have `strict` set to `True`.
    """
    if not is_given(tools):
        return omit

    # A one-shot iterable (e.g. a generator) would be exhausted by the
    # validation loop below and handed back empty, so snapshot it first.
    # Lists and tuples can be re-iterated and are returned as-is.
    checked_tools = tools if isinstance(tools, (list, tuple)) else list(tools)

    for tool in checked_tools:
        if tool["type"] != "function":
            raise ValueError(
                f"Currently only `function` tool types support auto-parsing; Received `{tool['type']}`",
            )

        strict = tool["function"].get("strict")
        if strict is not True:
            raise ValueError(
                f"`{tool['function']['name']}` is not strict. Only `strict` function tools can be auto-parsed"
            )

    return cast(Iterable[ChatCompletionFunctionToolParam], checked_tools)
def parse_chat_completion(
    *,
    response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | Omit,
    input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
    chat_completion: ChatCompletion | ParsedChatCompletion[object],
) -> ParsedChatCompletion[ResponseFormatT]:
    """Build a `ParsedChatCompletion` from a chat completion.

    For every choice, function tool calls get a `parsed_arguments` entry and the
    message gets a `parsed` entry (see `parse_function_tool_arguments` and
    `maybe_parse_content`). Custom tool calls are logged and left out of the
    parsed tool call list.

    Raises:
        LengthFinishReasonError: If a choice finished with reason `length`.
        ContentFilterFinishReasonError: If a choice finished with reason
            `content_filter`.
    """
    # Snapshot the tools so they can be searched once per tool call.
    tool_params: list[ChatCompletionToolUnionParam] = list(input_tools) if is_given(input_tools) else []

    parsed_choices: list[ParsedChoice[ResponseFormatT]] = []
    for choice in chat_completion.choices:
        finish_reason = choice.finish_reason
        if finish_reason == "length":
            raise LengthFinishReasonError(completion=chat_completion)
        if finish_reason == "content_filter":
            raise ContentFilterFinishReasonError()

        message = choice.message

        parsed_calls: list[ParsedFunctionToolCall] = []
        for tool_call in message.tool_calls or []:
            if tool_call.type == "function":
                raw_call = tool_call.to_dict()
                function_payload = {
                    **cast(Any, raw_call["function"]),
                    "parsed_arguments": parse_function_tool_arguments(
                        input_tools=tool_params, function=tool_call.function
                    ),
                }
                parsed_calls.append(
                    construct_type_unchecked(
                        value={**raw_call, "function": function_payload},
                        type_=ParsedFunctionToolCall,
                    )
                )
            elif tool_call.type == "custom":
                # warn user that custom tool calls are not callable here
                log.warning(
                    "Custom tool calls are not callable. Ignoring tool call: %s - %s",
                    tool_call.id,
                    tool_call.custom.name,
                    stacklevel=2,
                )
            elif TYPE_CHECKING:  # type: ignore[unreachable]
                assert_never(tool_call)
            else:
                # Unrecognised tool call type at runtime: keep it untouched.
                parsed_calls.append(tool_call)

        choice_type = cast(Any, ParsedChoice)[solve_response_format_t(response_format)]
        choice_payload = {
            **choice.to_dict(),
            "message": {
                **message.to_dict(),
                "parsed": maybe_parse_content(
                    response_format=response_format,
                    message=message,
                ),
                # An empty list is reported as `None`.
                "tool_calls": parsed_calls or None,
            },
        }
        parsed_choices.append(construct_type_unchecked(type_=choice_type, value=choice_payload))

    completion_type = cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)]
    completion_payload = {
        **chat_completion.to_dict(),
        "choices": parsed_choices,
    }
    return cast(
        ParsedChatCompletion[ResponseFormatT],
        construct_type_unchecked(type_=completion_type, value=completion_payload),
    )
def get_input_tool_by_name(
    *, input_tools: list[ChatCompletionToolUnionParam], name: str
) -> ChatCompletionFunctionToolParam | None:
    """Return the first `function` tool in `input_tools` whose function name is `name`.

    Tools of other types are skipped, as are function tools without a matching
    `name`. Returns `None` when nothing matches.
    """
    for candidate in input_tools:
        if candidate["type"] != "function":
            continue
        if candidate.get("function", {}).get("name") == name:
            return candidate
    return None
def parse_function_tool_arguments(
    *, input_tools: list[ChatCompletionToolUnionParam], function: Function | ParsedFunction
) -> object | None:
    """Parse the arguments of a function tool call against its input tool.

    The input tool is looked up by `function.name`. If its definition is a
    `PydanticFunctionTool`, the arguments are parsed into that tool's model.
    If it is a plain definition with a truthy `strict`, the arguments are
    decoded with `json.loads`. In every other case, including when no tool
    matches, the result is `None`.
    """
    matched_tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
    if not matched_tool:
        return None

    definition = cast(object, matched_tool.get("function"))
    if isinstance(definition, PydanticFunctionTool):
        return model_parse_json(definition.model, function.arguments)

    if cast(FunctionDefinition, definition).get("strict"):
        return json.loads(function.arguments)  # type: ignore[no-any-return]

    return None
def maybe_parse_content(
    *,
    response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
    message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
) -> ResponseFormatT | None:
    """Parse `message.content` into `response_format` when that is possible.

    Parsing only happens when a rich response format was given, the message has
    non-empty content, and the message carries no refusal. Otherwise the result
    is `None`.
    """
    if has_rich_response_format(response_format):
        content = message.content
        if content and not message.refusal:
            return _parse_content(response_format, content)

    return None
def solve_response_format_t(
    response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> type[ResponseFormatT]:
    """Resolve the runtime type to parametrise parsed models with.

    A rich response format is returned unchanged. When no response format was
    given, or it is one that is not auto-parsed, the module default (`None`)
    is returned instead.
    """
    return (
        response_format
        if has_rich_response_format(response_format)
        else cast("type[ResponseFormatT]", _default_response_format)
    )
def has_parseable_input(
    *,
    response_format: type | ResponseFormatParam | Omit,
    input_tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
) -> bool:
    """Report whether the request contains anything that can be auto-parsed.

    That is the case when the response format is rich, or when at least one of
    the input tools is accepted by `is_parseable_tool`.
    """
    if has_rich_response_format(response_format):
        return True

    # A falsy `input_tools` (nothing supplied, or empty) is treated as no tools.
    return any(is_parseable_tool(tool) for tool in input_tools or [])
def has_rich_response_format(
    response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
) -> TypeGuard[type[ResponseFormatT]]:
    """Tell whether `response_format` was given and is not a response format param.

    A response format that passes `is_response_format_param` is not considered
    rich.
    """
    return is_given(response_format) and not is_response_format_param(response_format)
def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
    """Treat `response_format` as a response format param when `is_dict` accepts it."""
    return is_dict(response_format)
def is_parseable_tool(input_tool: ChatCompletionToolUnionParam) -> bool:
    """Tell whether tool calls for `input_tool` can be auto-parsed.

    Only `function` tools qualify: either the function definition is a
    `PydanticFunctionTool`, or its `strict` entry is truthy.
    """
    if input_tool["type"] == "function":
        fn_definition = cast(object, input_tool.get("function"))
        if isinstance(fn_definition, PydanticFunctionTool):
            return True

        strict_flag = cast(FunctionDefinition, fn_definition).get("strict")
        return strict_flag or False

    return False
def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
    """Parse the JSON string `content` into an instance of `response_format`.

    Types accepted by `is_basemodel_type` are parsed with `model_parse_json`.
    Types accepted by `is_dataclass_like_type` are validated through a
    `pydantic.TypeAdapter`.

    Raises:
        TypeError: If `response_format` is neither kind of type, or if it is
            dataclass-like while `PYDANTIC_V1` is set.
    """
    if is_basemodel_type(response_format):
        parsed_model = model_parse_json(response_format, content)
        return cast(ResponseFormatT, parsed_model)

    if not is_dataclass_like_type(response_format):
        raise TypeError(f"Unable to automatically parse response format type {response_format}")

    if PYDANTIC_V1:
        raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")

    adapter = pydantic.TypeAdapter(response_format)
    return adapter.validate_json(content)
def type_to_response_format_param(
    response_format: type | completion_create_params.ResponseFormat | Omit,
) -> ResponseFormatParam | Omit:
    """Convert `response_format` into the param form.

    `omit` is returned when nothing was given, and a value that already is a
    response format param is passed through unchanged. A type is turned into a
    strict `json_schema` param named after the type.

    Raises:
        TypeError: If the type is neither a base model type nor a
            dataclass-like type.
    """
    if not is_given(response_format):
        return omit

    if is_response_format_param(response_format):
        return response_format

    # Type checkers do not narrow on the negative branch of a `TypeGuard`,
    # but only a `type` can reach this point, so narrow by hand.
    target_type = cast(type, response_format)

    schema_source: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
    if is_basemodel_type(target_type):
        schema_name = target_type.__name__
        schema_source = target_type
    elif is_dataclass_like_type(target_type):
        schema_name = target_type.__name__
        schema_source = pydantic.TypeAdapter(target_type)
    else:
        raise TypeError(f"Unsupported response_format type - {response_format}")

    json_schema = {
        "schema": to_strict_json_schema(schema_source),
        "name": schema_name,
        "strict": True,
    }
    return {
        "type": "json_schema",
        "json_schema": json_schema,
    }
|