# migrate.py
  1. from __future__ import annotations
  2. import os
  3. import sys
  4. import shutil
  5. import tarfile
  6. import platform
  7. import subprocess
  8. from typing import TYPE_CHECKING, List
  9. from pathlib import Path
  10. from argparse import ArgumentParser
  11. import httpx
  12. from .._errors import CLIError, SilentCLIError
  13. from .._models import BaseModel
  14. if TYPE_CHECKING:
  15. from argparse import _SubParsersAction
  16. def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
  17. sub = subparser.add_parser("migrate")
  18. sub.set_defaults(func=migrate, args_model=MigrateArgs, allow_unknown_args=True)
  19. sub = subparser.add_parser("grit")
  20. sub.set_defaults(func=grit, args_model=GritArgs, allow_unknown_args=True)
  21. class GritArgs(BaseModel):
  22. # internal
  23. unknown_args: List[str] = []
  24. def grit(args: GritArgs) -> None:
  25. grit_path = install()
  26. try:
  27. subprocess.check_call([grit_path, *args.unknown_args])
  28. except subprocess.CalledProcessError:
  29. # stdout and stderr are forwarded by subprocess so an error will already
  30. # have been displayed
  31. raise SilentCLIError() from None
  32. class MigrateArgs(BaseModel):
  33. # internal
  34. unknown_args: List[str] = []
  35. def migrate(args: MigrateArgs) -> None:
  36. grit_path = install()
  37. try:
  38. subprocess.check_call([grit_path, "apply", "openai", *args.unknown_args])
  39. except subprocess.CalledProcessError:
  40. # stdout and stderr are forwarded by subprocess so an error will already
  41. # have been displayed
  42. raise SilentCLIError() from None
  43. # handles downloading the Grit CLI until they provide their own PyPi package
  44. KEYGEN_ACCOUNT = "custodian-dev"
  45. def _cache_dir() -> Path:
  46. xdg = os.environ.get("XDG_CACHE_HOME")
  47. if xdg is not None:
  48. return Path(xdg)
  49. return Path.home() / ".cache"
  50. def _debug(message: str) -> None:
  51. if not os.environ.get("DEBUG"):
  52. return
  53. sys.stdout.write(f"[DEBUG]: {message}\n")
  54. def install() -> Path:
  55. """Installs the Grit CLI and returns the location of the binary"""
  56. if sys.platform == "win32":
  57. raise CLIError("Windows is not supported yet in the migration CLI")
  58. _debug("Using Grit installer from GitHub")
  59. platform = "apple-darwin" if sys.platform == "darwin" else "unknown-linux-gnu"
  60. dir_name = _cache_dir() / "openai-python"
  61. install_dir = dir_name / ".install"
  62. target_dir = install_dir / "bin"
  63. target_path = target_dir / "grit"
  64. temp_file = target_dir / "grit.tmp"
  65. if target_path.exists():
  66. _debug(f"{target_path} already exists")
  67. sys.stdout.flush()
  68. return target_path
  69. _debug(f"Using Grit CLI path: {target_path}")
  70. target_dir.mkdir(parents=True, exist_ok=True)
  71. if temp_file.exists():
  72. temp_file.unlink()
  73. arch = _get_arch()
  74. _debug(f"Using architecture {arch}")
  75. file_name = f"grit-{arch}-{platform}"
  76. download_url = f"https://github.com/getgrit/gritql/releases/latest/download/{file_name}.tar.gz"
  77. sys.stdout.write(f"Downloading Grit CLI from {download_url}\n")
  78. with httpx.Client() as client:
  79. download_response = client.get(download_url, follow_redirects=True)
  80. if download_response.status_code != 200:
  81. raise CLIError(f"Failed to download Grit CLI from {download_url}")
  82. with open(temp_file, "wb") as file:
  83. for chunk in download_response.iter_bytes():
  84. file.write(chunk)
  85. unpacked_dir = target_dir / "cli-bin"
  86. unpacked_dir.mkdir(parents=True, exist_ok=True)
  87. with tarfile.open(temp_file, "r:gz") as archive:
  88. if sys.version_info >= (3, 12):
  89. archive.extractall(unpacked_dir, filter="data")
  90. else:
  91. archive.extractall(unpacked_dir)
  92. _move_files_recursively(unpacked_dir, target_dir)
  93. shutil.rmtree(unpacked_dir)
  94. os.remove(temp_file)
  95. os.chmod(target_path, 0o755)
  96. sys.stdout.flush()
  97. return target_path
  98. def _move_files_recursively(source_dir: Path, target_dir: Path) -> None:
  99. for item in source_dir.iterdir():
  100. if item.is_file():
  101. item.rename(target_dir / item.name)
  102. elif item.is_dir():
  103. _move_files_recursively(item, target_dir)
  104. def _get_arch() -> str:
  105. architecture = platform.machine().lower()
  106. # Map the architecture names to Grit equivalents
  107. arch_map = {
  108. "x86_64": "x86_64",
  109. "amd64": "x86_64",
  110. "armv7l": "aarch64",
  111. "arm64": "aarch64",
  112. }
  113. return arch_map.get(architecture, architecture)