_utils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. import ctypes
  2. import sys
  3. from typing import Any, Optional, Union
  4. import torch
  5. # The _get_device_index has been moved to torch.utils._get_device_index
  6. from torch._utils import _get_device_index as _torch_get_device_index
  7. # Load CUDA driver and NVRTC
  8. def _get_cuda_library() -> ctypes.CDLL:
  9. if sys.platform == "win32":
  10. return ctypes.CDLL("nvcuda.dll")
  11. else: # Unix-based systems
  12. return ctypes.CDLL("libcuda.so.1")
  13. # Helper: check CUDA errors
  14. def _check_cuda(result: int) -> None:
  15. if result == 0:
  16. return
  17. err_str = ctypes.c_char_p()
  18. libcuda = _get_cuda_library() # Get reference to CUDA library
  19. libcuda.cuGetErrorString(result, ctypes.byref(err_str))
  20. error_message = (
  21. err_str.value.decode() if err_str.value is not None else "Unknown CUDA error"
  22. )
  23. raise RuntimeError(f"CUDA error: {error_message}")
  24. def _get_nvrtc_library() -> ctypes.CDLL:
  25. major_version = int(torch.version.cuda.split(".")[0]) # type: ignore[union-attr]
  26. if sys.platform == "win32":
  27. nvrtc_libs = [
  28. f"nvrtc64_{major_version}0_0.dll",
  29. ]
  30. else:
  31. nvrtc_libs = [
  32. f"libnvrtc.so.{major_version}",
  33. "libnvrtc.so", # Fallback to unversioned
  34. ]
  35. for lib_name in nvrtc_libs:
  36. try:
  37. return ctypes.CDLL(lib_name)
  38. except OSError:
  39. continue
  40. raise OSError("Could not find any NVRTC library")
  41. def _nvrtc_compile(
  42. kernel_source: str,
  43. kernel_name: str,
  44. compute_capability: Optional[str] = None,
  45. header_code: str = "",
  46. cuda_include_dirs: Optional[list] = None,
  47. nvcc_options: Optional[list] = None,
  48. ) -> bytes:
  49. """
  50. Compiles a CUDA kernel using NVRTC and returns the PTX code.
  51. Args:
  52. kernel_source (str): The CUDA kernel source code as a string
  53. kernel_name (str): The name of the kernel function to compile
  54. compute_capability (str, None): The compute capability to target (e.g., "86").
  55. If None, will detect from current device.
  56. header_code (str, optional): Additional header code to prepend to the kernel source
  57. cuda_include_dirs (list, None): List of directories containing CUDA headers
  58. nvcc_options (list, None): Additional options to pass to NVRTC
  59. Returns:
  60. str: The compiled PTX code
  61. """
  62. # Ensure CUDA is initialized
  63. import torch.cuda
  64. # Load NVRTC library
  65. libnvrtc = _get_nvrtc_library()
  66. # NVRTC constants
  67. NVRTC_SUCCESS = 0
  68. # Helper: check NVRTC errors
  69. def check_nvrtc(result: int) -> None:
  70. if result != NVRTC_SUCCESS:
  71. err_str = ctypes.c_char_p()
  72. libnvrtc.nvrtcGetErrorString(result, ctypes.byref(err_str))
  73. error_message = (
  74. err_str.value.decode()
  75. if err_str.value is not None
  76. else "Unknown CUDA error"
  77. )
  78. raise RuntimeError(f"CUDA error: {error_message}")
  79. # Add 'extern "C"' if not already present to ensure C linkage
  80. if not kernel_source.strip().startswith('extern "C"'):
  81. kernel_source = f'extern "C" {kernel_source}'
  82. # Combine header code and kernel source
  83. if header_code:
  84. full_source = header_code + "\n" + kernel_source
  85. else:
  86. full_source = kernel_source
  87. # Convert source to bytes
  88. source_bytes = full_source.encode("utf-8")
  89. # Get compute capability if not provided
  90. if compute_capability is None:
  91. props = torch.cuda.get_device_properties(torch.cuda.current_device())
  92. compute_capability = f"{props.major}{props.minor}"
  93. # Prepare compilation options
  94. options = []
  95. options.append(f"--gpu-architecture=sm_{compute_capability}".encode())
  96. # Add custom include directories
  97. if cuda_include_dirs:
  98. for directory in cuda_include_dirs:
  99. options.append(f"-I{directory}".encode())
  100. # Add custom NVCC options
  101. if nvcc_options:
  102. for option in nvcc_options:
  103. options.append(option.encode("utf-8"))
  104. # TODO: Should we refactor flags into a common place?
  105. from torch.utils.cpp_extension import COMMON_NVCC_FLAGS
  106. # Filter out flags not supported by NVRTC
  107. nvrtc_compatible_flags = [
  108. flag for flag in COMMON_NVCC_FLAGS if flag != "--expt-relaxed-constexpr"
  109. ]
  110. options.extend([flag.encode("utf-8") for flag in nvrtc_compatible_flags])
  111. # Convert options to C array
  112. num_options = len(options)
  113. options_array = (ctypes.c_char_p * num_options)(*options)
  114. # Create program
  115. prog = ctypes.c_void_p()
  116. check_nvrtc(
  117. libnvrtc.nvrtcCreateProgram(
  118. ctypes.byref(prog),
  119. source_bytes,
  120. f"{kernel_name}.cu".encode(),
  121. 0,
  122. None,
  123. None,
  124. )
  125. )
  126. # Compile program
  127. res = libnvrtc.nvrtcCompileProgram(prog, num_options, options_array)
  128. # Handle compilation errors
  129. if res != NVRTC_SUCCESS:
  130. # Get log
  131. log_size = ctypes.c_size_t()
  132. libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size))
  133. log = ctypes.create_string_buffer(log_size.value)
  134. libnvrtc.nvrtcGetProgramLog(prog, log)
  135. raise RuntimeError(f"Kernel compilation failed:\n{log.value.decode()}")
  136. # Get PTX
  137. ptx_size = ctypes.c_size_t()
  138. check_nvrtc(libnvrtc.nvrtcGetPTXSize(prog, ctypes.byref(ptx_size)))
  139. ptx = ctypes.create_string_buffer(ptx_size.value)
  140. check_nvrtc(libnvrtc.nvrtcGetPTX(prog, ptx))
  141. libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog))
  142. return ptx.value
  143. class _CudaModule:
  144. def __init__(self, module: ctypes.c_void_p) -> None:
  145. self._module = module
  146. self._kernels: dict[str, _CudaKernel] = {}
  147. def __getattr__(self, name: str) -> "_CudaKernel":
  148. if name in self._kernels:
  149. return self._kernels[name]
  150. # Import the CUDA library inside the method
  151. from torch.cuda._utils import _get_cuda_library
  152. libcuda = _get_cuda_library()
  153. func = ctypes.c_void_p()
  154. try:
  155. _check_cuda(
  156. libcuda.cuModuleGetFunction(
  157. ctypes.byref(func), self._module, name.encode("utf-8")
  158. )
  159. )
  160. kernel = _CudaKernel(func, self._module)
  161. self._kernels[name] = kernel
  162. return kernel
  163. except RuntimeError as err:
  164. raise AttributeError(f"No kernel named '{name}' in this module") from err
  165. class _CudaKernel:
  166. """
  167. Represents a compiled CUDA kernel that can be called with PyTorch tensors.
  168. """
  169. def __init__(self, func: ctypes.c_void_p, module: ctypes.c_void_p) -> None:
  170. self.func = func
  171. self.module = module
  172. def __call__(
  173. self,
  174. grid: tuple[int, int, int] = (1, 1, 1),
  175. block: tuple[int, int, int] = (1, 1, 1),
  176. args: Optional[list] = None,
  177. shared_mem: int = 0,
  178. stream: Optional[Any] = None,
  179. ) -> None:
  180. """
  181. Call the compiled CUDA kernel
  182. Args:
  183. grid (tuple): Grid dimensions (grid_x, grid_y, grid_z)
  184. block (tuple): Block dimensions (block_x, block_y, block_z)
  185. args (list): List of arguments to pass to the kernel.
  186. PyTorch tensor arguments will be automatically converted to pointers.
  187. shared_mem (int): Shared memory size in bytes
  188. stream (torch.cuda.Stream): CUDA stream to use. If None, uses current stream.
  189. """
  190. import torch
  191. libcuda = torch.cuda._utils._get_cuda_library()
  192. if not args:
  193. args = []
  194. # Process arguments and convert tensors to pointers
  195. processed_args: list[ctypes.c_void_p] = []
  196. c_args = []
  197. for arg in args:
  198. if isinstance(arg, torch.Tensor):
  199. if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()):
  200. raise ValueError(
  201. "All tensor arguments must be CUDA tensors or pinned CPU tensors"
  202. )
  203. # Get pointer to tensor data
  204. ptr = ctypes.c_void_p(arg.data_ptr())
  205. processed_args.append(ptr)
  206. c_args.append(ctypes.byref(ptr))
  207. elif isinstance(arg, int):
  208. # Convert integers to C int
  209. c_int = ctypes.c_int(arg)
  210. # Store the C int for reference keeping, not in processed_args
  211. c_args.append(ctypes.byref(c_int))
  212. # TODO: Python floats are actually doubles
  213. elif isinstance(arg, float):
  214. # Convert floats to C float
  215. c_float = ctypes.c_float(arg)
  216. # Store the C float for reference keeping, not in processed_args
  217. c_args.append(ctypes.byref(c_float))
  218. else:
  219. raise TypeError(f"Unsupported argument type: {type(arg)}")
  220. # Convert to array of void pointers
  221. c_args_array = (ctypes.c_void_p * len(c_args))()
  222. for i, arg in enumerate(c_args):
  223. c_args_array[i] = ctypes.cast(arg, ctypes.c_void_p)
  224. # Get the stream
  225. if stream is None:
  226. # Defer import to avoid circular imports
  227. import torch.cuda
  228. stream = torch.cuda.current_stream()
  229. _check_cuda(
  230. libcuda.cuLaunchKernel(
  231. self.func,
  232. grid[0],
  233. grid[1],
  234. grid[2],
  235. block[0],
  236. block[1],
  237. block[2],
  238. shared_mem,
  239. stream._as_parameter_,
  240. c_args_array,
  241. None,
  242. )
  243. )
  244. def _cuda_load_module(
  245. ptx: Union[str, bytes], kernel_names: Optional[list[str]] = None
  246. ) -> Union[_CudaModule, dict[str, "_CudaKernel"]]:
  247. """
  248. Loads a CUDA module from PTX code and returns a module object that can access kernels.
  249. Args:
  250. ptx (bytes or str): The PTX code to load
  251. kernel_names (list, optional): List of kernel names to extract from the module.
  252. If None, will return a module object with __getattr__.
  253. Returns:
  254. object: If kernel_names is None, returns a module object with __getattr__ to access kernels.
  255. If kernel_names is provided, returns a dict mapping kernel names to _CudaKernel objects.
  256. """
  257. # Ensure CUDA is initialized
  258. import torch.cuda
  259. # Load CUDA driver library
  260. libcuda = _get_cuda_library()
  261. # Convert PTX to bytes if it's a string
  262. if isinstance(ptx, str):
  263. ptx = ptx.encode("utf-8")
  264. # Load PTX module
  265. module = ctypes.c_void_p()
  266. # Get the current stream without directly importing torch.cuda at module level
  267. stream = torch.cuda.current_stream()
  268. with stream:
  269. _check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx))
  270. if not kernel_names:
  271. return _CudaModule(module)
  272. # Return specific kernels
  273. kernels = {}
  274. for name in kernel_names:
  275. func = ctypes.c_void_p()
  276. _check_cuda(
  277. libcuda.cuModuleGetFunction(
  278. ctypes.byref(func), module, name.encode("utf-8")
  279. )
  280. )
  281. kernels[name] = _CudaKernel(func, module)
  282. return kernels
  283. def _get_device_index(
  284. device: Any, optional: bool = False, allow_cpu: bool = False
  285. ) -> int:
  286. r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
  287. If :attr:`device` is a torch.device object, returns the device index if it
  288. is a CUDA device. Note that for a CUDA device without a specified index,
  289. i.e., ``torch.device('cuda')``, this will return the current default CUDA
  290. device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
  291. CPU devices will be accepted and ``-1`` will be returned in this case.
  292. If :attr:`device` is a Python integer, it is returned as is.
  293. If :attr:`device` is ``None``, this will return the current default CUDA
  294. device if :attr:`optional` is ``True``.
  295. """
  296. if isinstance(device, int):
  297. return device
  298. if isinstance(device, str):
  299. device = torch.device(device)
  300. if isinstance(device, torch.device):
  301. if allow_cpu:
  302. if device.type not in ["cuda", "cpu"]:
  303. raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
  304. elif device.type != "cuda":
  305. raise ValueError(f"Expected a cuda device, but got: {device}")
  306. if not torch.jit.is_scripting():
  307. if isinstance(device, torch.cuda.device):
  308. return device.idx
  309. return _torch_get_device_index(device, optional, allow_cpu)