_tools.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. from __future__ import annotations
  2. from typing import Any, Dict, cast
  3. import pydantic
  4. from ._pydantic import to_strict_json_schema
  5. from ..types.chat import ChatCompletionFunctionToolParam
  6. from ..types.shared_params import FunctionDefinition
  7. from ..types.responses.function_tool_param import FunctionToolParam as ResponsesFunctionToolParam
  8. class PydanticFunctionTool(Dict[str, Any]):
  9. """Dictionary wrapper so we can pass the given base model
  10. throughout the entire request stack without having to special
  11. case it.
  12. """
  13. model: type[pydantic.BaseModel]
  14. def __init__(self, defn: FunctionDefinition, model: type[pydantic.BaseModel]) -> None:
  15. super().__init__(defn)
  16. self.model = model
  17. def cast(self) -> FunctionDefinition:
  18. return cast(FunctionDefinition, self)
  19. class ResponsesPydanticFunctionTool(Dict[str, Any]):
  20. model: type[pydantic.BaseModel]
  21. def __init__(self, tool: ResponsesFunctionToolParam, model: type[pydantic.BaseModel]) -> None:
  22. super().__init__(tool)
  23. self.model = model
  24. def cast(self) -> ResponsesFunctionToolParam:
  25. return cast(ResponsesFunctionToolParam, self)
  26. def pydantic_function_tool(
  27. model: type[pydantic.BaseModel],
  28. *,
  29. name: str | None = None, # inferred from class name by default
  30. description: str | None = None, # inferred from class docstring by default
  31. ) -> ChatCompletionFunctionToolParam:
  32. if description is None:
  33. # note: we intentionally don't use `.getdoc()` to avoid
  34. # including pydantic's docstrings
  35. description = model.__doc__
  36. function = PydanticFunctionTool(
  37. {
  38. "name": name or model.__name__,
  39. "strict": True,
  40. "parameters": to_strict_json_schema(model),
  41. },
  42. model,
  43. ).cast()
  44. if description is not None:
  45. function["description"] = description
  46. return {
  47. "type": "function",
  48. "function": function,
  49. }