completions.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. from __future__ import annotations
  2. import sys
  3. from typing import TYPE_CHECKING, List, Optional, cast
  4. from argparse import ArgumentParser
  5. from typing_extensions import Literal, NamedTuple
  6. from ..._utils import get_client
  7. from ..._models import BaseModel
  8. from ...._streaming import Stream
  9. from ....types.chat import (
  10. ChatCompletionRole,
  11. ChatCompletionChunk,
  12. CompletionCreateParams,
  13. )
  14. from ....types.chat.completion_create_params import (
  15. CompletionCreateParamsStreaming,
  16. CompletionCreateParamsNonStreaming,
  17. )
  18. if TYPE_CHECKING:
  19. from argparse import _SubParsersAction
  def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
      """Register the ``chat.completions.create`` sub-command on *subparser*.

      Adds the sub-command's required arguments (``--message``, ``--model``) and
      optional arguments (``--n``, ``--max-tokens``, ``--temperature``,
      ``--top_p``, ``--stop``, ``--stream``), then stores the handler
      ``CLIChatCompletion.create`` and the argument model
      ``CLIChatCompletionCreateArgs`` as parser defaults (``func`` and
      ``args_model``).

      Args:
          subparser: The sub-parsers action the new command is attached to.
      """
      sub = subparser.add_parser("chat.completions.create")

      # Drop the last built-in argument group so that the help output is
      # organised under the explicit "required" / "optional" groups below.
      sub._action_groups.pop()
      req = sub.add_argument_group("required arguments")
      opt = sub.add_argument_group("optional arguments")

      req.add_argument(
          "-g",
          "--message",
          action="append",
          nargs=2,
          metavar=("ROLE", "CONTENT"),
          help="A message in `{role} {content}` format. Use this argument multiple times to add multiple messages.",
          required=True,
      )
      req.add_argument(
          "-m",
          "--model",
          help="The model to use.",
          required=True,
      )

      opt.add_argument(
          "-n",
          "--n",
          help="How many completions to generate for the conversation.",
          type=int,
      )
      opt.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate.", type=int)
      opt.add_argument(
          "-t",
          "--temperature",
          help=(
              "What sampling temperature to use. Higher values mean the model will take more risks. "
              "Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a "
              "well-defined answer.\n"
              "Mutually exclusive with `top_p`."
          ),
          type=float,
      )
      opt.add_argument(
          "-P",
          "--top_p",
          # argparse %-formats help strings, hence the doubled percent sign.
          help=(
              "An alternative to sampling with temperature, called nucleus sampling, where the model "
              "considers the results of the tokens with top_p probability mass. So 0.1 means only the "
              "tokens comprising the top 10%% probability mass are considered.\n"
              "Mutually exclusive with `temperature`."
          ),
          type=float,
      )
      opt.add_argument(
          "--stop",
          help="A stop sequence at which to stop generating tokens for the message.",
      )
      opt.add_argument("--stream", help="Stream messages as they're ready.", action="store_true")

      sub.set_defaults(func=CLIChatCompletion.create, args_model=CLIChatCompletionCreateArgs)
  class CLIMessage(NamedTuple):
      """A single chat message given on the command line as a ROLE/CONTENT pair."""

      # First value of the pair: the role of the message author.
      role: ChatCompletionRole
      # Second value of the pair: the text of the message.
      content: str
  class CLIChatCompletionCreateArgs(BaseModel):
      """Parsed arguments of the ``chat.completions.create`` sub-command."""

      # --- required ---------------------------------------------------------
      # One entry per ``--message ROLE CONTENT`` occurrence.
      message: List[CLIMessage]
      # Value of ``--model``.
      model: str

      # --- optional (None means "not given on the command line") -------------
      # Value of ``--n``.
      n: Optional[int] = None
      # Value of ``--max-tokens``.
      max_tokens: Optional[int] = None
      # Value of ``--temperature``.
      temperature: Optional[float] = None
      # Value of ``--top_p``.
      top_p: Optional[float] = None
      # Value of ``--stop``.
      stop: Optional[str] = None
      # True when ``--stream`` was passed.
      stream: bool = False
  class CLIChatCompletion:
      """Handlers behind the ``chat.completions.create`` CLI sub-command."""

      @staticmethod
      def create(args: CLIChatCompletionCreateArgs) -> None:
          """Build the request parameters from *args* and run the request.

          Only options that were actually supplied are added to the request.
          The output is streamed when ``args.stream`` is set and printed in one
          go otherwise.
          """
          params: CompletionCreateParams = {
              "model": args.model,
              "messages": [
                  {"role": cast(Literal["user"], entry.role), "content": entry.content} for entry in args.message
              ],
              # type checkers are not good at inferring union types so we have to set stream afterwards
              "stream": False,
          }

          if args.temperature is not None:
              params["temperature"] = args.temperature
          if args.stop is not None:
              params["stop"] = args.stop
          if args.top_p is not None:
              params["top_p"] = args.top_p
          if args.n is not None:
              params["n"] = args.n
          if args.max_tokens is not None:
              params["max_tokens"] = args.max_tokens

          if not args.stream:
              CLIChatCompletion._create(cast(CompletionCreateParamsNonStreaming, params))
              return

          params["stream"] = args.stream  # type: ignore
          CLIChatCompletion._stream_create(cast(CompletionCreateParamsStreaming, params))

      @staticmethod
      def _create(params: CompletionCreateParamsNonStreaming) -> None:
          """Request a completion and write every returned choice to stdout.

          A header line precedes each choice when more than one is returned, and
          each choice's output is terminated with a newline.
          """
          response = get_client().chat.completions.create(**params)
          choices = response.choices
          several = len(choices) > 1

          for choice in choices:
              if several:
                  sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))

              text = choice.message.content
              if text is None:
                  # Missing content is printed as the literal string "None".
                  text = "None"
              sys.stdout.write(text)

              if several or not text.endswith("\n"):
                  sys.stdout.write("\n")

              sys.stdout.flush()

      @staticmethod
      def _stream_create(params: CompletionCreateParamsStreaming) -> None:
          """Request a streamed completion and write each delta to stdout as it arrives.

          stdout is flushed after every choice of every chunk, and a final
          newline is written once the stream is exhausted.
          """
          # cast is required for mypy
          chunks = cast(  # pyright: ignore[reportUnnecessaryCast]
              Stream[ChatCompletionChunk], get_client().chat.completions.create(**params)
          )

          for chunk in chunks:
              choices = chunk.choices
              several = len(choices) > 1

              for choice in choices:
                  if several:
                      sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))

                  piece = choice.delta.content
                  if not piece:
                      # A delta without content contributes nothing to the output.
                      piece = ""
                  sys.stdout.write(piece)

                  if several:
                      sys.stdout.write("\n")

                  sys.stdout.flush()

          sys.stdout.write("\n")