completions.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. from __future__ import annotations
  2. import sys
  3. from typing import TYPE_CHECKING, Optional, cast
  4. from argparse import ArgumentParser
  5. from functools import partial
  6. from openai.types.completion import Completion
  7. from .._utils import get_client
  8. from ..._types import Omittable, omit
  9. from ..._utils import is_given
  10. from .._errors import CLIError
  11. from .._models import BaseModel
  12. from ..._streaming import Stream
  13. if TYPE_CHECKING:
  14. from argparse import _SubParsersAction
  15. def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
  16. sub = subparser.add_parser("completions.create")
  17. # Required
  18. sub.add_argument(
  19. "-m",
  20. "--model",
  21. help="The model to use",
  22. required=True,
  23. )
  24. # Optional
  25. sub.add_argument("-p", "--prompt", help="An optional prompt to complete from")
  26. sub.add_argument("--stream", help="Stream tokens as they're ready.", action="store_true")
  27. sub.add_argument("-M", "--max-tokens", help="The maximum number of tokens to generate", type=int)
  28. sub.add_argument(
  29. "-t",
  30. "--temperature",
  31. help="""What sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
  32. Mutually exclusive with `top_p`.""",
  33. type=float,
  34. )
  35. sub.add_argument(
  36. "-P",
  37. "--top_p",
  38. help="""An alternative to sampling with temperature, called nucleus sampling, where the considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.
  39. Mutually exclusive with `temperature`.""",
  40. type=float,
  41. )
  42. sub.add_argument(
  43. "-n",
  44. "--n",
  45. help="How many sub-completions to generate for each prompt.",
  46. type=int,
  47. )
  48. sub.add_argument(
  49. "--logprobs",
  50. help="Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. So for example, if `logprobs` is 10, the API will return a list of the 10 most likely tokens. If `logprobs` is 0, only the chosen tokens will have logprobs returned.",
  51. type=int,
  52. )
  53. sub.add_argument(
  54. "--best_of",
  55. help="Generates `best_of` completions server-side and returns the 'best' (the one with the highest log probability per token). Results cannot be streamed.",
  56. type=int,
  57. )
  58. sub.add_argument(
  59. "--echo",
  60. help="Echo back the prompt in addition to the completion",
  61. action="store_true",
  62. )
  63. sub.add_argument(
  64. "--frequency_penalty",
  65. help="Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.",
  66. type=float,
  67. )
  68. sub.add_argument(
  69. "--presence_penalty",
  70. help="Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.",
  71. type=float,
  72. )
  73. sub.add_argument("--suffix", help="The suffix that comes after a completion of inserted text.")
  74. sub.add_argument("--stop", help="A stop sequence at which to stop generating tokens.")
  75. sub.add_argument(
  76. "--user",
  77. help="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.",
  78. )
  79. # TODO: add support for logit_bias
  80. sub.set_defaults(func=CLICompletions.create, args_model=CLICompletionCreateArgs)
class CLICompletionCreateArgs(BaseModel):
    """Parsed arguments for the ``completions.create`` CLI subcommand.

    Field names mirror the ``dest`` of each option registered in ``register``
    and the keyword arguments forwarded by ``CLICompletions.create``.
    """

    # Required `-m/--model`.
    model: str
    # `--stream` flag; when True the response is printed chunk by chunk.
    stream: bool = False
    # `-p/--prompt`; None when not supplied.
    prompt: Optional[str] = None
    # The remaining fields default to the `omit` sentinel. They are presumably
    # dropped from the request when left unset (compare `is_given(args.n)` in
    # `CLICompletions.create`) — NOTE(review): confirm against `omit` semantics.
    n: Omittable[int] = omit
    stop: Omittable[str] = omit
    user: Omittable[str] = omit
    echo: Omittable[bool] = omit
    suffix: Omittable[str] = omit
    best_of: Omittable[int] = omit
    top_p: Omittable[float] = omit
    logprobs: Omittable[int] = omit
    # Populated from `-M/--max-tokens` (argparse converts the dash to an underscore).
    max_tokens: Omittable[int] = omit
    temperature: Omittable[float] = omit
    presence_penalty: Omittable[float] = omit
    frequency_penalty: Omittable[float] = omit
  97. class CLICompletions:
  98. @staticmethod
  99. def create(args: CLICompletionCreateArgs) -> None:
  100. if is_given(args.n) and args.n > 1 and args.stream:
  101. raise CLIError("Can't stream completions with n>1 with the current CLI")
  102. make_request = partial(
  103. get_client().completions.create,
  104. n=args.n,
  105. echo=args.echo,
  106. stop=args.stop,
  107. user=args.user,
  108. model=args.model,
  109. top_p=args.top_p,
  110. prompt=args.prompt,
  111. suffix=args.suffix,
  112. best_of=args.best_of,
  113. logprobs=args.logprobs,
  114. max_tokens=args.max_tokens,
  115. temperature=args.temperature,
  116. presence_penalty=args.presence_penalty,
  117. frequency_penalty=args.frequency_penalty,
  118. )
  119. if args.stream:
  120. return CLICompletions._stream_create(
  121. # mypy doesn't understand the `partial` function but pyright does
  122. cast(Stream[Completion], make_request(stream=True)) # pyright: ignore[reportUnnecessaryCast]
  123. )
  124. return CLICompletions._create(make_request())
  125. @staticmethod
  126. def _create(completion: Completion) -> None:
  127. should_print_header = len(completion.choices) > 1
  128. for choice in completion.choices:
  129. if should_print_header:
  130. sys.stdout.write("===== Completion {} =====\n".format(choice.index))
  131. sys.stdout.write(choice.text)
  132. if should_print_header or not choice.text.endswith("\n"):
  133. sys.stdout.write("\n")
  134. sys.stdout.flush()
  135. @staticmethod
  136. def _stream_create(stream: Stream[Completion]) -> None:
  137. for completion in stream:
  138. should_print_header = len(completion.choices) > 1
  139. for choice in sorted(completion.choices, key=lambda c: c.index):
  140. if should_print_header:
  141. sys.stdout.write("===== Chat Completion {} =====\n".format(choice.index))
  142. sys.stdout.write(choice.text)
  143. if should_print_header:
  144. sys.stdout.write("\n")
  145. sys.stdout.flush()
  146. sys.stdout.write("\n")