_triton.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. import functools
  2. import hashlib
  3. from typing import Any
  4. @functools.cache
  5. def has_triton_package() -> bool:
  6. try:
  7. import triton # noqa: F401
  8. return True
  9. except ImportError:
  10. return False
  11. @functools.cache
  12. def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
  13. try:
  14. import triton # noqa: F401
  15. major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2])
  16. return (major, minor)
  17. except ImportError:
  18. return fallback
  19. @functools.cache
  20. def _device_supports_tma() -> bool:
  21. import torch
  22. return (
  23. torch.cuda.is_available()
  24. and torch.cuda.get_device_capability() >= (9, 0)
  25. and not torch.version.hip
  26. )
  27. @functools.cache
  28. def has_triton_experimental_host_tma() -> bool:
  29. if has_triton_package():
  30. if _device_supports_tma():
  31. try:
  32. from triton.tools.experimental_descriptor import ( # noqa: F401
  33. create_1d_tma_descriptor,
  34. create_2d_tma_descriptor,
  35. )
  36. return True
  37. except ImportError:
  38. pass
  39. return False
  40. @functools.cache
  41. def has_triton_tensor_descriptor_host_tma() -> bool:
  42. if has_triton_package():
  43. if _device_supports_tma():
  44. try:
  45. from triton.tools.tensor_descriptor import ( # noqa: F401
  46. TensorDescriptor,
  47. )
  48. return True
  49. except ImportError:
  50. pass
  51. return False
  52. @functools.cache
  53. def has_triton_tma() -> bool:
  54. return has_triton_tensor_descriptor_host_tma() or has_triton_experimental_host_tma()
  55. @functools.cache
  56. def has_triton_tma_device() -> bool:
  57. if has_triton_package():
  58. import torch
  59. if (
  60. torch.cuda.is_available()
  61. and torch.cuda.get_device_capability() >= (9, 0)
  62. and not torch.version.hip
  63. ) or torch.xpu.is_available():
  64. # old API
  65. try:
  66. from triton.language.extra.cuda import ( # noqa: F401
  67. experimental_device_tensormap_create1d,
  68. experimental_device_tensormap_create2d,
  69. )
  70. return True
  71. except ImportError:
  72. pass
  73. # new API
  74. try:
  75. from triton.language import make_tensor_descriptor # noqa: F401
  76. return True
  77. except ImportError:
  78. pass
  79. return False
  80. @functools.lru_cache(None)
  81. def has_triton_stable_tma_api() -> bool:
  82. if has_triton_package():
  83. import torch
  84. if (
  85. torch.cuda.is_available()
  86. and torch.cuda.get_device_capability() >= (9, 0)
  87. and not torch.version.hip
  88. ) or torch.xpu.is_available():
  89. try:
  90. from triton.language import make_tensor_descriptor # noqa: F401
  91. return True
  92. except ImportError:
  93. pass
  94. return False
  95. @functools.cache
  96. def has_triton() -> bool:
  97. if not has_triton_package():
  98. return False
  99. from torch._dynamo.device_interface import get_interface_for_device
  100. def cuda_extra_check(device_interface: Any) -> bool:
  101. return device_interface.Worker.get_device_properties().major >= 7
  102. def cpu_extra_check(device_interface: Any) -> bool:
  103. import triton.backends
  104. return "cpu" in triton.backends.backends
  105. def _return_true(device_interface: Any) -> bool:
  106. return True
  107. triton_supported_devices = {
  108. "cuda": cuda_extra_check,
  109. "xpu": _return_true,
  110. "cpu": cpu_extra_check,
  111. "mtia": _return_true,
  112. }
  113. def is_device_compatible_with_triton() -> bool:
  114. for device, extra_check in triton_supported_devices.items():
  115. device_interface = get_interface_for_device(device)
  116. if device_interface.is_available() and extra_check(device_interface):
  117. return True
  118. return False
  119. return is_device_compatible_with_triton()
  120. @functools.cache
  121. def triton_backend() -> Any:
  122. from triton.compiler.compiler import make_backend
  123. from triton.runtime.driver import driver
  124. target = driver.active.get_current_target()
  125. return make_backend(target)
  126. @functools.cache
  127. def triton_hash_with_backend() -> str:
  128. from torch._inductor.runtime.triton_compat import triton_key
  129. backend = triton_backend()
  130. key = f"{triton_key()}-{backend.hash()}"
  131. # Hash is upper case so that it can't contain any Python keywords.
  132. return hashlib.sha256(key.encode("utf-8")).hexdigest().upper()