_pallas.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import functools
  2. import torch
  3. @functools.cache
  4. def has_jax_package() -> bool:
  5. """Check if JAX is installed."""
  6. try:
  7. import jax # noqa: F401 # type: ignore[import-not-found]
  8. return True
  9. except ImportError:
  10. return False
  11. @functools.cache
  12. def has_pallas_package() -> bool:
  13. """Check if Pallas (JAX experimental) is available."""
  14. if not has_jax_package():
  15. return False
  16. try:
  17. from jax.experimental import ( # noqa: F401 # type: ignore[import-not-found]
  18. pallas as pl,
  19. )
  20. return True
  21. except ImportError:
  22. return False
  23. @functools.cache
  24. def get_jax_version(fallback: tuple[int, int, int] = (0, 0, 0)) -> tuple[int, int, int]:
  25. """Get JAX version as (major, minor, patch) tuple."""
  26. try:
  27. import jax # type: ignore[import-not-found]
  28. version_parts = jax.__version__.split(".")
  29. major, minor, patch = (int(v) for v in version_parts[:3])
  30. return (major, minor, patch)
  31. except (ImportError, ValueError, AttributeError):
  32. return fallback
  33. @functools.cache
  34. def has_jax_cuda_backend() -> bool:
  35. """Check if JAX has CUDA backend support."""
  36. if not has_jax_package():
  37. return False
  38. try:
  39. import jax # type: ignore[import-not-found]
  40. # Check if CUDA backend is available
  41. devices = jax.devices("gpu")
  42. return len(devices) > 0
  43. except Exception:
  44. return False
  45. @functools.cache
  46. def has_jax_tpu_backend() -> bool:
  47. """Check if JAX has TPU backend support."""
  48. if not has_jax_package():
  49. return False
  50. try:
  51. import jax # type: ignore[import-not-found]
  52. # Check if TPU backend is available
  53. devices = jax.devices("tpu")
  54. return len(devices) > 0
  55. except Exception:
  56. return False
  57. @functools.cache
  58. def has_cpu_pallas() -> bool:
  59. """Checks for a full Pallas-on-CPU environment."""
  60. return has_pallas_package()
  61. @functools.cache
  62. def has_cuda_pallas() -> bool:
  63. """Checks for a full Pallas-on-CUDA environment."""
  64. return has_pallas_package() and torch.cuda.is_available() and has_jax_cuda_backend()
  65. @functools.cache
  66. def has_tpu_pallas() -> bool:
  67. """Checks for a full Pallas-on-TPU environment."""
  68. return has_pallas_package() and has_jax_tpu_backend()
  69. @functools.cache
  70. def has_pallas() -> bool:
  71. """
  72. Check if Pallas backend is fully available for use.
  73. Requirements:
  74. - JAX package installed
  75. - Pallas (jax.experimental.pallas) available
  76. - A compatible backend (CUDA or TPU) is available in both PyTorch and JAX.
  77. """
  78. return has_cpu_pallas() or has_cuda_pallas() or has_tpu_pallas()