_completions.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. from __future__ import annotations
  2. import json
  3. import logging
  4. from typing import TYPE_CHECKING, Any, Iterable, cast
  5. from typing_extensions import TypeVar, TypeGuard, assert_never
  6. import pydantic
  7. from .._tools import PydanticFunctionTool
  8. from ..._types import Omit, omit
  9. from ..._utils import is_dict, is_given
  10. from ..._compat import PYDANTIC_V1, model_parse_json
  11. from ..._models import construct_type_unchecked
  12. from .._pydantic import is_basemodel_type, to_strict_json_schema, is_dataclass_like_type
  13. from ...types.chat import (
  14. ParsedChoice,
  15. ChatCompletion,
  16. ParsedFunction,
  17. ParsedChatCompletion,
  18. ChatCompletionMessage,
  19. ParsedFunctionToolCall,
  20. ParsedChatCompletionMessage,
  21. ChatCompletionToolUnionParam,
  22. ChatCompletionFunctionToolParam,
  23. completion_create_params,
  24. )
  25. from ..._exceptions import LengthFinishReasonError, ContentFilterFinishReasonError
  26. from ...types.shared_params import FunctionDefinition
  27. from ...types.chat.completion_create_params import ResponseFormat as ResponseFormatParam
  28. from ...types.chat.chat_completion_message_function_tool_call import Function
  29. ResponseFormatT = TypeVar(
  30. "ResponseFormatT",
  31. # if it isn't given then we don't do any parsing
  32. default=None,
  33. )
  34. _default_response_format: None = None
  35. log: logging.Logger = logging.getLogger("openai.lib.parsing")
  36. def is_strict_chat_completion_tool_param(
  37. tool: ChatCompletionToolUnionParam,
  38. ) -> TypeGuard[ChatCompletionFunctionToolParam]:
  39. """Check if the given tool is a strict ChatCompletionFunctionToolParam."""
  40. if not tool["type"] == "function":
  41. return False
  42. if tool["function"].get("strict") is not True:
  43. return False
  44. return True
  45. def select_strict_chat_completion_tools(
  46. tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
  47. ) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
  48. """Select only the strict ChatCompletionFunctionToolParams from the given tools."""
  49. if not is_given(tools):
  50. return omit
  51. return [t for t in tools if is_strict_chat_completion_tool_param(t)]
  52. def validate_input_tools(
  53. tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
  54. ) -> Iterable[ChatCompletionFunctionToolParam] | Omit:
  55. if not is_given(tools):
  56. return omit
  57. for tool in tools:
  58. if tool["type"] != "function":
  59. raise ValueError(
  60. f"Currently only `function` tool types support auto-parsing; Received `{tool['type']}`",
  61. )
  62. strict = tool["function"].get("strict")
  63. if strict is not True:
  64. raise ValueError(
  65. f"`{tool['function']['name']}` is not strict. Only `strict` function tools can be auto-parsed"
  66. )
  67. return cast(Iterable[ChatCompletionFunctionToolParam], tools)
  68. def parse_chat_completion(
  69. *,
  70. response_format: type[ResponseFormatT] | completion_create_params.ResponseFormat | Omit,
  71. input_tools: Iterable[ChatCompletionToolUnionParam] | Omit,
  72. chat_completion: ChatCompletion | ParsedChatCompletion[object],
  73. ) -> ParsedChatCompletion[ResponseFormatT]:
  74. if is_given(input_tools):
  75. input_tools = [t for t in input_tools]
  76. else:
  77. input_tools = []
  78. choices: list[ParsedChoice[ResponseFormatT]] = []
  79. for choice in chat_completion.choices:
  80. if choice.finish_reason == "length":
  81. raise LengthFinishReasonError(completion=chat_completion)
  82. if choice.finish_reason == "content_filter":
  83. raise ContentFilterFinishReasonError()
  84. message = choice.message
  85. tool_calls: list[ParsedFunctionToolCall] = []
  86. if message.tool_calls:
  87. for tool_call in message.tool_calls:
  88. if tool_call.type == "function":
  89. tool_call_dict = tool_call.to_dict()
  90. tool_calls.append(
  91. construct_type_unchecked(
  92. value={
  93. **tool_call_dict,
  94. "function": {
  95. **cast(Any, tool_call_dict["function"]),
  96. "parsed_arguments": parse_function_tool_arguments(
  97. input_tools=input_tools, function=tool_call.function
  98. ),
  99. },
  100. },
  101. type_=ParsedFunctionToolCall,
  102. )
  103. )
  104. elif tool_call.type == "custom":
  105. # warn user that custom tool calls are not callable here
  106. log.warning(
  107. "Custom tool calls are not callable. Ignoring tool call: %s - %s",
  108. tool_call.id,
  109. tool_call.custom.name,
  110. stacklevel=2,
  111. )
  112. elif TYPE_CHECKING: # type: ignore[unreachable]
  113. assert_never(tool_call)
  114. else:
  115. tool_calls.append(tool_call)
  116. choices.append(
  117. construct_type_unchecked(
  118. type_=cast(Any, ParsedChoice)[solve_response_format_t(response_format)],
  119. value={
  120. **choice.to_dict(),
  121. "message": {
  122. **message.to_dict(),
  123. "parsed": maybe_parse_content(
  124. response_format=response_format,
  125. message=message,
  126. ),
  127. "tool_calls": tool_calls if tool_calls else None,
  128. },
  129. },
  130. )
  131. )
  132. return cast(
  133. ParsedChatCompletion[ResponseFormatT],
  134. construct_type_unchecked(
  135. type_=cast(Any, ParsedChatCompletion)[solve_response_format_t(response_format)],
  136. value={
  137. **chat_completion.to_dict(),
  138. "choices": choices,
  139. },
  140. ),
  141. )
  142. def get_input_tool_by_name(
  143. *, input_tools: list[ChatCompletionToolUnionParam], name: str
  144. ) -> ChatCompletionFunctionToolParam | None:
  145. return next((t for t in input_tools if t["type"] == "function" and t.get("function", {}).get("name") == name), None)
  146. def parse_function_tool_arguments(
  147. *, input_tools: list[ChatCompletionToolUnionParam], function: Function | ParsedFunction
  148. ) -> object | None:
  149. input_tool = get_input_tool_by_name(input_tools=input_tools, name=function.name)
  150. if not input_tool:
  151. return None
  152. input_fn = cast(object, input_tool.get("function"))
  153. if isinstance(input_fn, PydanticFunctionTool):
  154. return model_parse_json(input_fn.model, function.arguments)
  155. input_fn = cast(FunctionDefinition, input_fn)
  156. if not input_fn.get("strict"):
  157. return None
  158. return json.loads(function.arguments) # type: ignore[no-any-return]
  159. def maybe_parse_content(
  160. *,
  161. response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
  162. message: ChatCompletionMessage | ParsedChatCompletionMessage[object],
  163. ) -> ResponseFormatT | None:
  164. if has_rich_response_format(response_format) and message.content and not message.refusal:
  165. return _parse_content(response_format, message.content)
  166. return None
  167. def solve_response_format_t(
  168. response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
  169. ) -> type[ResponseFormatT]:
  170. """Return the runtime type for the given response format.
  171. If no response format is given, or if we won't auto-parse the response format
  172. then we default to `None`.
  173. """
  174. if has_rich_response_format(response_format):
  175. return response_format
  176. return cast("type[ResponseFormatT]", _default_response_format)
  177. def has_parseable_input(
  178. *,
  179. response_format: type | ResponseFormatParam | Omit,
  180. input_tools: Iterable[ChatCompletionToolUnionParam] | Omit = omit,
  181. ) -> bool:
  182. if has_rich_response_format(response_format):
  183. return True
  184. for input_tool in input_tools or []:
  185. if is_parseable_tool(input_tool):
  186. return True
  187. return False
  188. def has_rich_response_format(
  189. response_format: type[ResponseFormatT] | ResponseFormatParam | Omit,
  190. ) -> TypeGuard[type[ResponseFormatT]]:
  191. if not is_given(response_format):
  192. return False
  193. if is_response_format_param(response_format):
  194. return False
  195. return True
  196. def is_response_format_param(response_format: object) -> TypeGuard[ResponseFormatParam]:
  197. return is_dict(response_format)
  198. def is_parseable_tool(input_tool: ChatCompletionToolUnionParam) -> bool:
  199. if input_tool["type"] != "function":
  200. return False
  201. input_fn = cast(object, input_tool.get("function"))
  202. if isinstance(input_fn, PydanticFunctionTool):
  203. return True
  204. return cast(FunctionDefinition, input_fn).get("strict") or False
  205. def _parse_content(response_format: type[ResponseFormatT], content: str) -> ResponseFormatT:
  206. if is_basemodel_type(response_format):
  207. return cast(ResponseFormatT, model_parse_json(response_format, content))
  208. if is_dataclass_like_type(response_format):
  209. if PYDANTIC_V1:
  210. raise TypeError(f"Non BaseModel types are only supported with Pydantic v2 - {response_format}")
  211. return pydantic.TypeAdapter(response_format).validate_json(content)
  212. raise TypeError(f"Unable to automatically parse response format type {response_format}")
  213. def type_to_response_format_param(
  214. response_format: type | completion_create_params.ResponseFormat | Omit,
  215. ) -> ResponseFormatParam | Omit:
  216. if not is_given(response_format):
  217. return omit
  218. if is_response_format_param(response_format):
  219. return response_format
  220. # type checkers don't narrow the negation of a `TypeGuard` as it isn't
  221. # a safe default behaviour but we know that at this point the `response_format`
  222. # can only be a `type`
  223. response_format = cast(type, response_format)
  224. json_schema_type: type[pydantic.BaseModel] | pydantic.TypeAdapter[Any] | None = None
  225. if is_basemodel_type(response_format):
  226. name = response_format.__name__
  227. json_schema_type = response_format
  228. elif is_dataclass_like_type(response_format):
  229. name = response_format.__name__
  230. json_schema_type = pydantic.TypeAdapter(response_format)
  231. else:
  232. raise TypeError(f"Unsupported response_format type - {response_format}")
  233. return {
  234. "type": "json_schema",
  235. "json_schema": {
  236. "schema": to_strict_json_schema(json_schema_type),
  237. "name": name,
  238. "strict": True,
  239. },
  240. }