image.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING, Any, cast
  3. from argparse import ArgumentParser
  4. from .._utils import get_client, print_model
  5. from ..._types import Omit, Omittable, omit
  6. from .._models import BaseModel
  7. from .._progress import BufferReader
  8. if TYPE_CHECKING:
  9. from argparse import _SubParsersAction
  def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
      """Attach the ``images.*`` subcommands to the given subparser collection.

      Three subcommands are added: ``images.generate``, ``images.edit`` and
      ``images.create_variation``. Each parser stores, through ``set_defaults``,
      the handler to call (``func``) and the arguments model class
      (``args_model``) associated with it.

      Args:
          subparser: The object returned by ``ArgumentParser.add_subparsers()``
              on which the new parsers are created.
      """

      # The helpers below each add one option (or one fixed pair of options) that
      # is shared by several subcommands. They are invoked in the order the
      # options should be registered on each parser.

      def add_model_option(parser: ArgumentParser) -> None:
          parser.add_argument("-m", "--model", type=str)

      def add_prompt_option(parser: ArgumentParser) -> None:
          parser.add_argument("-p", "--prompt", type=str, required=True)

      def add_count_option(parser: ArgumentParser) -> None:
          parser.add_argument("-n", "--num-images", type=int, default=1)

      def add_source_image_option(parser: ArgumentParser) -> None:
          parser.add_argument(
              "-I",
              "--image",
              type=str,
              required=True,
              help="Image to modify. Should be a local path and a PNG encoded image.",
          )

      def add_output_options(parser: ArgumentParser) -> None:
          parser.add_argument("-s", "--size", type=str, default="1024x1024", help="Size of the output image")
          parser.add_argument("--response-format", type=str, default="url")

      # images.generate: create images from a text prompt.
      generate_parser = subparser.add_parser("images.generate")
      add_model_option(generate_parser)
      add_prompt_option(generate_parser)
      add_count_option(generate_parser)
      add_output_options(generate_parser)
      generate_parser.set_defaults(func=CLIImage.create, args_model=CLIImageCreateArgs)

      # images.edit: modify an existing image, optionally restricted by a mask.
      edit_parser = subparser.add_parser("images.edit")
      add_model_option(edit_parser)
      add_prompt_option(edit_parser)
      add_count_option(edit_parser)
      add_source_image_option(edit_parser)
      add_output_options(edit_parser)
      edit_parser.add_argument(
          "-M",
          "--mask",
          type=str,
          required=False,
          help="Path to a mask image. It should be the same size as the image you're editing and a RGBA PNG image. The Alpha channel acts as the mask.",
      )
      edit_parser.set_defaults(func=CLIImage.edit, args_model=CLIImageEditArgs)

      # images.create_variation: produce variations of an existing image (no prompt).
      variation_parser = subparser.add_parser("images.create_variation")
      add_model_option(variation_parser)
      add_count_option(variation_parser)
      add_source_image_option(variation_parser)
      add_output_options(variation_parser)
      variation_parser.set_defaults(func=CLIImage.create_variation, args_model=CLIImageCreateVariationArgs)
  class CLIImageCreateArgs(BaseModel):
      """Arguments model registered as ``args_model`` for the ``images.generate`` subcommand.

      Field names mirror the argparse destinations declared in ``register``;
      ``CLIImage.create`` reads them when building the API request.
      """

      # Text prompt; corresponds to the required `-p/--prompt` option.
      prompt: str
      # Number of images requested (`-n/--num-images`, argparse default 1);
      # forwarded to the API as `n`.
      num_images: int
      # Output size string (`-s/--size`, argparse default "1024x1024").
      size: str
      # Response format string (`--response-format`, argparse default "url").
      response_format: str
      # Optional model name (`-m/--model`); the field default is the `omit` sentinel.
      model: Omittable[str] = omit
  class CLIImageCreateVariationArgs(BaseModel):
      """Arguments model registered as ``args_model`` for the ``images.create_variation`` subcommand.

      Field names mirror the argparse destinations declared in ``register``;
      ``CLIImage.create_variation`` reads them when building the API request.
      """

      # Local path of the source image (`-I/--image`, required); the handler
      # opens this path in binary mode and uploads its contents.
      image: str
      # Number of images requested (`-n/--num-images`, argparse default 1);
      # forwarded to the API as `n`.
      num_images: int
      # Output size string (`-s/--size`, argparse default "1024x1024").
      size: str
      # Response format string (`--response-format`, argparse default "url").
      response_format: str
      # Optional model name (`-m/--model`); the field default is the `omit` sentinel.
      model: Omittable[str] = omit
  class CLIImageEditArgs(BaseModel):
      """Arguments model registered as ``args_model`` for the ``images.edit`` subcommand.

      Field names mirror the argparse destinations declared in ``register``;
      ``CLIImage.edit`` reads them when building the API request.
      """

      # Local path of the image to edit (`-I/--image`, required); the handler
      # opens this path in binary mode and uploads its contents.
      image: str
      # Number of images requested (`-n/--num-images`, argparse default 1);
      # forwarded to the API as `n`.
      num_images: int
      # Output size string (`-s/--size`, argparse default "1024x1024").
      size: str
      # Response format string (`--response-format`, argparse default "url").
      response_format: str
      # Text prompt; corresponds to the required `-p/--prompt` option.
      prompt: str
      # Optional local path of a mask image (`-M/--mask`); the field default is
      # the `omit` sentinel, in which case the handler uploads no mask.
      mask: Omittable[str] = omit
      # Optional model name (`-m/--model`); the field default is the `omit` sentinel.
      model: Omittable[str] = omit
  class CLIImage:
      """Handlers behind the ``images.*`` CLI subcommands.

      Each public handler is a static method that takes its parsed arguments
      model, sends one request through ``get_client().images`` and prints the
      returned object with ``print_model``.
      """

      @staticmethod
      def _load_upload(path: str, desc: str) -> BufferReader:
          """Read the file at ``path`` in binary mode and wrap its bytes in a ``BufferReader``.

          Args:
              path: Local path of the file to read in full.
              desc: Description string handed to ``BufferReader``.
          """
          with open(path, "rb") as handle:
              return BufferReader(handle.read(), desc=desc)

      @staticmethod
      def create(args: CLIImageCreateArgs) -> None:
          """Call ``images.generate`` with the parsed arguments and print the result."""
          client = get_client()
          response = client.images.generate(
              model=args.model,
              prompt=args.prompt,
              n=args.num_images,
              # The API signature types these two as enums; cast to `Any` so that
              # values are not validated here, for forwards compatibility.
              size=cast(Any, args.size),
              response_format=cast(Any, args.response_format),
          )
          print_model(response)

      @staticmethod
      def create_variation(args: CLIImageCreateVariationArgs) -> None:
          """Upload the source image to ``images.create_variation`` and print the result."""
          # The file is read before the client is obtained.
          source = CLIImage._load_upload(args.image, "Upload progress")

          response = get_client().images.create_variation(
              model=args.model,
              image=("image", source),
              n=args.num_images,
              # The API signature types these two as enums; cast to `Any` so that
              # values are not validated here, for forwards compatibility.
              size=cast(Any, args.size),
              response_format=cast(Any, args.response_format),
          )
          print_model(response)

      @staticmethod
      def edit(args: CLIImageEditArgs) -> None:
          """Upload the image (and the mask, when given) to ``images.edit`` and print the result."""
          source = CLIImage._load_upload(args.image, "Image upload progress")

          # The mask is optional: keep the `omit` sentinel unless a path was supplied,
          # in which case the file is read and sent as a ("mask", reader) pair.
          mask_upload: Omittable[tuple[str, BufferReader]] = omit
          if not isinstance(args.mask, Omit):
              mask_upload = ("mask", CLIImage._load_upload(args.mask, "Mask progress"))

          response = get_client().images.edit(
              model=args.model,
              prompt=args.prompt,
              image=("image", source),
              n=args.num_images,
              mask=mask_upload,
              # The API signature types these two as enums; cast to `Any` so that
              # values are not validated here, for forwards compatibility.
              size=cast(Any, args.size),
              response_format=cast(Any, args.response_format),
          )
          print_model(response)
  119. print_model(image)